1// SPDX-License-Identifier: GPL-2.0-or-later
2/*
3 * Copyright (c) 2016 Mellanox Technologies. All rights reserved.
4 * Copyright (c) 2016 Jiri Pirko <jiri@mellanox.com>
5 */
6
7#include <net/genetlink.h>
8#include <net/sock.h>
9
10#include "devl_internal.h"
11
/* Flags consumed by __devlink_nl_pre_doit() / __devlink_nl_post_doit(). */

/* A devlink port must be resolved from the request; pre-doit fails if not. */
#define DEVLINK_NL_FLAG_NEED_PORT		BIT(0)
/* Resolve a devlink port if possible, but don't fail when none is found. */
#define DEVLINK_NL_FLAG_NEED_DEVLINK_OR_PORT	BIT(1)
/* Passed as dev_lock to devl_dev_lock()/devl_dev_unlock(). */
#define DEVLINK_NL_FLAG_NEED_DEV_LOCK		BIT(2)
15
/* Multicast groups of the devlink genl family, indexed by DEVLINK_MCGRP_*. */
static const struct genl_multicast_group devlink_nl_mcgrps[] = {
	[DEVLINK_MCGRP_CONFIG] = { .name = DEVLINK_GENL_MCGRP_CONFIG_NAME },
};
19
/* Per-socket private data of the devlink genl family (see the sock_priv_*
 * hooks in devlink_nl_family).
 */
struct devlink_nl_sock_priv {
	/* Notification filter set by devlink_nl_notify_filter_set_doit();
	 * NULL means "no filter". Read under RCU, replaced under flt_lock.
	 */
	struct devlink_obj_desc __rcu *flt;
	spinlock_t flt_lock; /* Protects flt. */
};
24
25static void devlink_nl_sock_priv_init(void *priv)
26{
27	struct devlink_nl_sock_priv *sk_priv = priv;
28
29	spin_lock_init(&sk_priv->flt_lock);
30}
31
32static void devlink_nl_sock_priv_destroy(void *priv)
33{
34	struct devlink_nl_sock_priv *sk_priv = priv;
35	struct devlink_obj_desc *flt;
36
37	flt = rcu_dereference_protected(sk_priv->flt, true);
38	kfree_rcu(flt, rcu);
39}
40
41int devlink_nl_notify_filter_set_doit(struct sk_buff *skb,
42				      struct genl_info *info)
43{
44	struct devlink_nl_sock_priv *sk_priv;
45	struct nlattr **attrs = info->attrs;
46	struct devlink_obj_desc *flt;
47	size_t data_offset = 0;
48	size_t data_size = 0;
49	char *pos;
50
51	if (attrs[DEVLINK_ATTR_BUS_NAME])
52		data_size = size_add(data_size,
53				     nla_len(attrs[DEVLINK_ATTR_BUS_NAME]) + 1);
54	if (attrs[DEVLINK_ATTR_DEV_NAME])
55		data_size = size_add(data_size,
56				     nla_len(attrs[DEVLINK_ATTR_DEV_NAME]) + 1);
57
58	flt = kzalloc(size_add(sizeof(*flt), data_size), GFP_KERNEL);
59	if (!flt)
60		return -ENOMEM;
61
62	pos = (char *) flt->data;
63	if (attrs[DEVLINK_ATTR_BUS_NAME]) {
64		data_offset += nla_strscpy(pos,
65					   attrs[DEVLINK_ATTR_BUS_NAME],
66					   data_size) + 1;
67		flt->bus_name = pos;
68		pos += data_offset;
69	}
70	if (attrs[DEVLINK_ATTR_DEV_NAME]) {
71		nla_strscpy(pos, attrs[DEVLINK_ATTR_DEV_NAME],
72			    data_size - data_offset);
73		flt->dev_name = pos;
74	}
75
76	if (attrs[DEVLINK_ATTR_PORT_INDEX]) {
77		flt->port_index = nla_get_u32(attrs[DEVLINK_ATTR_PORT_INDEX]);
78		flt->port_index_valid = true;
79	}
80
81	/* Don't attach empty filter. */
82	if (!flt->bus_name && !flt->dev_name && !flt->port_index_valid) {
83		kfree(flt);
84		flt = NULL;
85	}
86
87	sk_priv = genl_sk_priv_get(&devlink_nl_family, NETLINK_CB(skb).sk);
88	if (IS_ERR(sk_priv)) {
89		kfree(flt);
90		return PTR_ERR(sk_priv);
91	}
92	spin_lock(&sk_priv->flt_lock);
93	flt = rcu_replace_pointer(sk_priv->flt, flt,
94				  lockdep_is_held(&sk_priv->flt_lock));
95	spin_unlock(&sk_priv->flt_lock);
96	kfree_rcu(flt, rcu);
97	return 0;
98}
99
100static bool devlink_obj_desc_match(const struct devlink_obj_desc *desc,
101				   const struct devlink_obj_desc *flt)
102{
103	if (desc->bus_name && flt->bus_name &&
104	    strcmp(desc->bus_name, flt->bus_name))
105		return false;
106	if (desc->dev_name && flt->dev_name &&
107	    strcmp(desc->dev_name, flt->dev_name))
108		return false;
109	if (desc->port_index_valid && flt->port_index_valid &&
110	    desc->port_index != flt->port_index)
111		return false;
112	return true;
113}
114
115int devlink_nl_notify_filter(struct sock *dsk, struct sk_buff *skb, void *data)
116{
117	struct devlink_obj_desc *desc = data;
118	struct devlink_nl_sock_priv *sk_priv;
119	struct devlink_obj_desc *flt;
120	int ret = 0;
121
122	rcu_read_lock();
123	sk_priv = __genl_sk_priv_get(&devlink_nl_family, dsk);
124	if (!IS_ERR_OR_NULL(sk_priv)) {
125		flt = rcu_dereference(sk_priv->flt);
126		if (flt)
127			ret = !devlink_obj_desc_match(desc, flt);
128	}
129	rcu_read_unlock();
130	return ret;
131}
132
133int devlink_nl_put_nested_handle(struct sk_buff *msg, struct net *net,
134				 struct devlink *devlink, int attrtype)
135{
136	struct nlattr *nested_attr;
137	struct net *devl_net;
138
139	nested_attr = nla_nest_start(msg, attrtype);
140	if (!nested_attr)
141		return -EMSGSIZE;
142	if (devlink_nl_put_handle(msg, devlink))
143		goto nla_put_failure;
144
145	rcu_read_lock();
146	devl_net = read_pnet_rcu(&devlink->_net);
147	if (!net_eq(net, devl_net)) {
148		int id = peernet2id_alloc(net, devl_net, GFP_ATOMIC);
149
150		rcu_read_unlock();
151		if (nla_put_s32(msg, DEVLINK_ATTR_NETNS_ID, id))
152			return -EMSGSIZE;
153	} else {
154		rcu_read_unlock();
155	}
156
157	nla_nest_end(msg, nested_attr);
158	return 0;
159
160nla_put_failure:
161	nla_nest_cancel(msg, nested_attr);
162	return -EMSGSIZE;
163}
164
165int devlink_nl_msg_reply_and_new(struct sk_buff **msg, struct genl_info *info)
166{
167	int err;
168
169	if (*msg) {
170		err = genlmsg_reply(*msg, info);
171		if (err)
172			return err;
173	}
174	*msg = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL);
175	if (!*msg)
176		return -ENOMEM;
177	return 0;
178}
179
180struct devlink *
181devlink_get_from_attrs_lock(struct net *net, struct nlattr **attrs,
182			    bool dev_lock)
183{
184	struct devlink *devlink;
185	unsigned long index;
186	char *busname;
187	char *devname;
188
189	if (!attrs[DEVLINK_ATTR_BUS_NAME] || !attrs[DEVLINK_ATTR_DEV_NAME])
190		return ERR_PTR(-EINVAL);
191
192	busname = nla_data(attrs[DEVLINK_ATTR_BUS_NAME]);
193	devname = nla_data(attrs[DEVLINK_ATTR_DEV_NAME]);
194
195	devlinks_xa_for_each_registered_get(net, index, devlink) {
196		if (strcmp(devlink->dev->bus->name, busname) == 0 &&
197		    strcmp(dev_name(devlink->dev), devname) == 0) {
198			devl_dev_lock(devlink, dev_lock);
199			if (devl_is_registered(devlink))
200				return devlink;
201			devl_dev_unlock(devlink, dev_lock);
202		}
203		devlink_put(devlink);
204	}
205
206	return ERR_PTR(-ENODEV);
207}
208
209static int __devlink_nl_pre_doit(struct sk_buff *skb, struct genl_info *info,
210				 u8 flags)
211{
212	bool dev_lock = flags & DEVLINK_NL_FLAG_NEED_DEV_LOCK;
213	struct devlink_port *devlink_port;
214	struct devlink *devlink;
215	int err;
216
217	devlink = devlink_get_from_attrs_lock(genl_info_net(info), info->attrs,
218					      dev_lock);
219	if (IS_ERR(devlink))
220		return PTR_ERR(devlink);
221
222	info->user_ptr[0] = devlink;
223	if (flags & DEVLINK_NL_FLAG_NEED_PORT) {
224		devlink_port = devlink_port_get_from_info(devlink, info);
225		if (IS_ERR(devlink_port)) {
226			err = PTR_ERR(devlink_port);
227			goto unlock;
228		}
229		info->user_ptr[1] = devlink_port;
230	} else if (flags & DEVLINK_NL_FLAG_NEED_DEVLINK_OR_PORT) {
231		devlink_port = devlink_port_get_from_info(devlink, info);
232		if (!IS_ERR(devlink_port))
233			info->user_ptr[1] = devlink_port;
234	}
235	return 0;
236
237unlock:
238	devl_dev_unlock(devlink, dev_lock);
239	devlink_put(devlink);
240	return err;
241}
242
243int devlink_nl_pre_doit(const struct genl_split_ops *ops,
244			struct sk_buff *skb, struct genl_info *info)
245{
246	return __devlink_nl_pre_doit(skb, info, 0);
247}
248
249int devlink_nl_pre_doit_port(const struct genl_split_ops *ops,
250			     struct sk_buff *skb, struct genl_info *info)
251{
252	return __devlink_nl_pre_doit(skb, info, DEVLINK_NL_FLAG_NEED_PORT);
253}
254
255int devlink_nl_pre_doit_dev_lock(const struct genl_split_ops *ops,
256				 struct sk_buff *skb, struct genl_info *info)
257{
258	return __devlink_nl_pre_doit(skb, info, DEVLINK_NL_FLAG_NEED_DEV_LOCK);
259}
260
261int devlink_nl_pre_doit_port_optional(const struct genl_split_ops *ops,
262				      struct sk_buff *skb,
263				      struct genl_info *info)
264{
265	return __devlink_nl_pre_doit(skb, info, DEVLINK_NL_FLAG_NEED_DEVLINK_OR_PORT);
266}
267
268static void __devlink_nl_post_doit(struct sk_buff *skb, struct genl_info *info,
269				   u8 flags)
270{
271	bool dev_lock = flags & DEVLINK_NL_FLAG_NEED_DEV_LOCK;
272	struct devlink *devlink;
273
274	devlink = info->user_ptr[0];
275	devl_dev_unlock(devlink, dev_lock);
276	devlink_put(devlink);
277}
278
279void devlink_nl_post_doit(const struct genl_split_ops *ops,
280			  struct sk_buff *skb, struct genl_info *info)
281{
282	__devlink_nl_post_doit(skb, info, 0);
283}
284
285void
286devlink_nl_post_doit_dev_lock(const struct genl_split_ops *ops,
287			      struct sk_buff *skb, struct genl_info *info)
288{
289	__devlink_nl_post_doit(skb, info, DEVLINK_NL_FLAG_NEED_DEV_LOCK);
290}
291
292static int devlink_nl_inst_single_dumpit(struct sk_buff *msg,
293					 struct netlink_callback *cb, int flags,
294					 devlink_nl_dump_one_func_t *dump_one,
295					 struct nlattr **attrs)
296{
297	struct devlink *devlink;
298	int err;
299
300	devlink = devlink_get_from_attrs_lock(sock_net(msg->sk), attrs, false);
301	if (IS_ERR(devlink))
302		return PTR_ERR(devlink);
303	err = dump_one(msg, devlink, cb, flags | NLM_F_DUMP_FILTERED);
304
305	devl_unlock(devlink);
306	devlink_put(devlink);
307
308	if (err != -EMSGSIZE)
309		return err;
310	return msg->len;
311}
312
313static int devlink_nl_inst_iter_dumpit(struct sk_buff *msg,
314				       struct netlink_callback *cb, int flags,
315				       devlink_nl_dump_one_func_t *dump_one)
316{
317	struct devlink_nl_dump_state *state = devlink_dump_state(cb);
318	struct devlink *devlink;
319	int err = 0;
320
321	while ((devlink = devlinks_xa_find_get(sock_net(msg->sk),
322					       &state->instance))) {
323		devl_lock(devlink);
324
325		if (devl_is_registered(devlink))
326			err = dump_one(msg, devlink, cb, flags);
327		else
328			err = 0;
329
330		devl_unlock(devlink);
331		devlink_put(devlink);
332
333		if (err)
334			break;
335
336		state->instance++;
337
338		/* restart sub-object walk for the next instance */
339		state->idx = 0;
340	}
341
342	if (err != -EMSGSIZE)
343		return err;
344	return msg->len;
345}
346
347int devlink_nl_dumpit(struct sk_buff *msg, struct netlink_callback *cb,
348		      devlink_nl_dump_one_func_t *dump_one)
349{
350	const struct genl_info *info = genl_info_dump(cb);
351	struct nlattr **attrs = info->attrs;
352	int flags = NLM_F_MULTI;
353
354	if (attrs &&
355	    (attrs[DEVLINK_ATTR_BUS_NAME] || attrs[DEVLINK_ATTR_DEV_NAME]))
356		return devlink_nl_inst_single_dumpit(msg, cb, flags, dump_one,
357						     attrs);
358	else
359		return devlink_nl_inst_iter_dumpit(msg, cb, flags, dump_one);
360}
361
362struct genl_family devlink_nl_family __ro_after_init = {
363	.name		= DEVLINK_GENL_NAME,
364	.version	= DEVLINK_GENL_VERSION,
365	.netnsok	= true,
366	.parallel_ops	= true,
367	.module		= THIS_MODULE,
368	.split_ops	= devlink_nl_ops,
369	.n_split_ops	= ARRAY_SIZE(devlink_nl_ops),
370	.resv_start_op	= DEVLINK_CMD_SELFTESTS_RUN + 1,
371	.mcgrps		= devlink_nl_mcgrps,
372	.n_mcgrps	= ARRAY_SIZE(devlink_nl_mcgrps),
373	.sock_priv_size		= sizeof(struct devlink_nl_sock_priv),
374	.sock_priv_init		= devlink_nl_sock_priv_init,
375	.sock_priv_destroy	= devlink_nl_sock_priv_destroy,
376};
377