// SPDX-License-Identifier: GPL-2.0-only
/* Copyright (c) 2023 Isovalent */
3
4#include <linux/netdevice.h>
5#include <linux/ethtool.h>
6#include <linux/etherdevice.h>
7#include <linux/filter.h>
8#include <linux/netfilter_netdev.h>
9#include <linux/bpf_mprog.h>
10#include <linux/indirect_call_wrapper.h>
11
12#include <net/netkit.h>
13#include <net/dst.h>
14#include <net/tcx.h>
15
16#define DRV_NAME "netkit"
17
18struct netkit {
19	/* Needed in fast-path */
20	struct net_device __rcu *peer;
21	struct bpf_mprog_entry __rcu *active;
22	enum netkit_action policy;
23	struct bpf_mprog_bundle	bundle;
24
25	/* Needed in slow-path */
26	enum netkit_mode mode;
27	bool primary;
28	u32 headroom;
29};
30
31struct netkit_link {
32	struct bpf_link link;
33	struct net_device *dev;
34	u32 location;
35};
36
37static __always_inline int
38netkit_run(const struct bpf_mprog_entry *entry, struct sk_buff *skb,
39	   enum netkit_action ret)
40{
41	const struct bpf_mprog_fp *fp;
42	const struct bpf_prog *prog;
43
44	bpf_mprog_foreach_prog(entry, fp, prog) {
45		bpf_compute_data_pointers(skb);
46		ret = bpf_prog_run(prog, skb);
47		if (ret != NETKIT_NEXT)
48			break;
49	}
50	return ret;
51}
52
53static void netkit_prep_forward(struct sk_buff *skb, bool xnet)
54{
55	skb_scrub_packet(skb, xnet);
56	skb->priority = 0;
57	nf_skip_egress(skb, true);
58}
59
/* Typed accessor for the netkit private area of @dev. */
static struct netkit *netkit_priv(const struct net_device *dev)
{
	struct netkit *nk = netdev_priv(dev);

	return nk;
}
64
65static netdev_tx_t netkit_xmit(struct sk_buff *skb, struct net_device *dev)
66{
67	struct netkit *nk = netkit_priv(dev);
68	enum netkit_action ret = READ_ONCE(nk->policy);
69	netdev_tx_t ret_dev = NET_XMIT_SUCCESS;
70	const struct bpf_mprog_entry *entry;
71	struct net_device *peer;
72	int len = skb->len;
73
74	rcu_read_lock();
75	peer = rcu_dereference(nk->peer);
76	if (unlikely(!peer || !(peer->flags & IFF_UP) ||
77		     !pskb_may_pull(skb, ETH_HLEN) ||
78		     skb_orphan_frags(skb, GFP_ATOMIC)))
79		goto drop;
80	netkit_prep_forward(skb, !net_eq(dev_net(dev), dev_net(peer)));
81	skb->dev = peer;
82	entry = rcu_dereference(nk->active);
83	if (entry)
84		ret = netkit_run(entry, skb, ret);
85	switch (ret) {
86	case NETKIT_NEXT:
87	case NETKIT_PASS:
88		skb->protocol = eth_type_trans(skb, skb->dev);
89		skb_postpull_rcsum(skb, eth_hdr(skb), ETH_HLEN);
90		if (likely(__netif_rx(skb) == NET_RX_SUCCESS)) {
91			dev_sw_netstats_tx_add(dev, 1, len);
92			dev_sw_netstats_rx_add(peer, len);
93		} else {
94			goto drop_stats;
95		}
96		break;
97	case NETKIT_REDIRECT:
98		dev_sw_netstats_tx_add(dev, 1, len);
99		skb_do_redirect(skb);
100		break;
101	case NETKIT_DROP:
102	default:
103drop:
104		kfree_skb(skb);
105drop_stats:
106		dev_core_stats_tx_dropped_inc(dev);
107		ret_dev = NET_XMIT_DROP;
108		break;
109	}
110	rcu_read_unlock();
111	return ret_dev;
112}
113
114static int netkit_open(struct net_device *dev)
115{
116	struct netkit *nk = netkit_priv(dev);
117	struct net_device *peer = rtnl_dereference(nk->peer);
118
119	if (!peer)
120		return -ENOTCONN;
121	if (peer->flags & IFF_UP) {
122		netif_carrier_on(dev);
123		netif_carrier_on(peer);
124	}
125	return 0;
126}
127
128static int netkit_close(struct net_device *dev)
129{
130	struct netkit *nk = netkit_priv(dev);
131	struct net_device *peer = rtnl_dereference(nk->peer);
132
133	netif_carrier_off(dev);
134	if (peer)
135		netif_carrier_off(peer);
136	return 0;
137}
138
139static int netkit_get_iflink(const struct net_device *dev)
140{
141	struct netkit *nk = netkit_priv(dev);
142	struct net_device *peer;
143	int iflink = 0;
144
145	rcu_read_lock();
146	peer = rcu_dereference(nk->peer);
147	if (peer)
148		iflink = READ_ONCE(peer->ifindex);
149	rcu_read_unlock();
150	return iflink;
151}
152
/* ndo_set_rx_mode: intentionally a no-op. */
static void netkit_set_multicast(struct net_device *dev)
{
	/* No filtering to program: the device takes whatever is pushed to it. */
}
157
158static void netkit_set_headroom(struct net_device *dev, int headroom)
159{
160	struct netkit *nk = netkit_priv(dev), *nk2;
161	struct net_device *peer;
162
163	if (headroom < 0)
164		headroom = NET_SKB_PAD;
165
166	rcu_read_lock();
167	peer = rcu_dereference(nk->peer);
168	if (unlikely(!peer))
169		goto out;
170
171	nk2 = netkit_priv(peer);
172	nk->headroom = headroom;
173	headroom = max(nk->headroom, nk2->headroom);
174
175	peer->needed_headroom = headroom;
176	dev->needed_headroom = headroom;
177out:
178	rcu_read_unlock();
179}
180
181INDIRECT_CALLABLE_SCOPE struct net_device *netkit_peer_dev(struct net_device *dev)
182{
183	return rcu_dereference(netkit_priv(dev)->peer);
184}
185
186static void netkit_get_stats(struct net_device *dev,
187			     struct rtnl_link_stats64 *stats)
188{
189	dev_fetch_sw_netstats(stats, dev->tstats);
190	stats->tx_dropped = DEV_STATS_READ(dev, tx_dropped);
191}
192
193static void netkit_uninit(struct net_device *dev);
194
195static const struct net_device_ops netkit_netdev_ops = {
196	.ndo_open		= netkit_open,
197	.ndo_stop		= netkit_close,
198	.ndo_start_xmit		= netkit_xmit,
199	.ndo_set_rx_mode	= netkit_set_multicast,
200	.ndo_set_rx_headroom	= netkit_set_headroom,
201	.ndo_get_iflink		= netkit_get_iflink,
202	.ndo_get_peer_dev	= netkit_peer_dev,
203	.ndo_get_stats64	= netkit_get_stats,
204	.ndo_uninit		= netkit_uninit,
205	.ndo_features_check	= passthru_features_check,
206};
207
208static void netkit_get_drvinfo(struct net_device *dev,
209			       struct ethtool_drvinfo *info)
210{
211	strscpy(info->driver, DRV_NAME, sizeof(info->driver));
212}
213
214static const struct ethtool_ops netkit_ethtool_ops = {
215	.get_drvinfo		= netkit_get_drvinfo,
216};
217
218static void netkit_setup(struct net_device *dev)
219{
220	static const netdev_features_t netkit_features_hw_vlan =
221		NETIF_F_HW_VLAN_CTAG_TX |
222		NETIF_F_HW_VLAN_CTAG_RX |
223		NETIF_F_HW_VLAN_STAG_TX |
224		NETIF_F_HW_VLAN_STAG_RX;
225	static const netdev_features_t netkit_features =
226		netkit_features_hw_vlan |
227		NETIF_F_SG |
228		NETIF_F_FRAGLIST |
229		NETIF_F_HW_CSUM |
230		NETIF_F_RXCSUM |
231		NETIF_F_SCTP_CRC |
232		NETIF_F_HIGHDMA |
233		NETIF_F_GSO_SOFTWARE |
234		NETIF_F_GSO_ENCAP_ALL;
235
236	ether_setup(dev);
237	dev->max_mtu = ETH_MAX_MTU;
238	dev->pcpu_stat_type = NETDEV_PCPU_STAT_TSTATS;
239
240	dev->flags |= IFF_NOARP;
241	dev->priv_flags &= ~IFF_TX_SKB_SHARING;
242	dev->priv_flags |= IFF_LIVE_ADDR_CHANGE;
243	dev->priv_flags |= IFF_PHONY_HEADROOM;
244	dev->priv_flags |= IFF_NO_QUEUE;
245
246	dev->ethtool_ops = &netkit_ethtool_ops;
247	dev->netdev_ops  = &netkit_netdev_ops;
248
249	dev->features |= netkit_features | NETIF_F_LLTX;
250	dev->hw_features = netkit_features;
251	dev->hw_enc_features = netkit_features;
252	dev->mpls_features = NETIF_F_HW_CSUM | NETIF_F_GSO_SOFTWARE;
253	dev->vlan_features = dev->features & ~netkit_features_hw_vlan;
254
255	dev->needs_free_netdev = true;
256
257	netif_set_tso_max_size(dev, GSO_MAX_SIZE);
258}
259
260static struct net *netkit_get_link_net(const struct net_device *dev)
261{
262	struct netkit *nk = netkit_priv(dev);
263	struct net_device *peer = rtnl_dereference(nk->peer);
264
265	return peer ? dev_net(peer) : dev_net(dev);
266}
267
268static int netkit_check_policy(int policy, struct nlattr *tb,
269			       struct netlink_ext_ack *extack)
270{
271	switch (policy) {
272	case NETKIT_PASS:
273	case NETKIT_DROP:
274		return 0;
275	default:
276		NL_SET_ERR_MSG_ATTR(extack, tb,
277				    "Provided default xmit policy not supported");
278		return -EINVAL;
279	}
280}
281
282static int netkit_check_mode(int mode, struct nlattr *tb,
283			     struct netlink_ext_ack *extack)
284{
285	switch (mode) {
286	case NETKIT_L2:
287	case NETKIT_L3:
288		return 0;
289	default:
290		NL_SET_ERR_MSG_ATTR(extack, tb,
291				    "Provided device mode can only be L2 or L3");
292		return -EINVAL;
293	}
294}
295
296static int netkit_validate(struct nlattr *tb[], struct nlattr *data[],
297			   struct netlink_ext_ack *extack)
298{
299	struct nlattr *attr = tb[IFLA_ADDRESS];
300
301	if (!attr)
302		return 0;
303	NL_SET_ERR_MSG_ATTR(extack, attr,
304			    "Setting Ethernet address is not supported");
305	return -EOPNOTSUPP;
306}
307
308static struct rtnl_link_ops netkit_link_ops;
309
310static int netkit_new_link(struct net *src_net, struct net_device *dev,
311			   struct nlattr *tb[], struct nlattr *data[],
312			   struct netlink_ext_ack *extack)
313{
314	struct nlattr *peer_tb[IFLA_MAX + 1], **tbp = tb, *attr;
315	enum netkit_action default_prim = NETKIT_PASS;
316	enum netkit_action default_peer = NETKIT_PASS;
317	enum netkit_mode mode = NETKIT_L3;
318	unsigned char ifname_assign_type;
319	struct ifinfomsg *ifmp = NULL;
320	struct net_device *peer;
321	char ifname[IFNAMSIZ];
322	struct netkit *nk;
323	struct net *net;
324	int err;
325
326	if (data) {
327		if (data[IFLA_NETKIT_MODE]) {
328			attr = data[IFLA_NETKIT_MODE];
329			mode = nla_get_u32(attr);
330			err = netkit_check_mode(mode, attr, extack);
331			if (err < 0)
332				return err;
333		}
334		if (data[IFLA_NETKIT_PEER_INFO]) {
335			attr = data[IFLA_NETKIT_PEER_INFO];
336			ifmp = nla_data(attr);
337			err = rtnl_nla_parse_ifinfomsg(peer_tb, attr, extack);
338			if (err < 0)
339				return err;
340			err = netkit_validate(peer_tb, NULL, extack);
341			if (err < 0)
342				return err;
343			tbp = peer_tb;
344		}
345		if (data[IFLA_NETKIT_POLICY]) {
346			attr = data[IFLA_NETKIT_POLICY];
347			default_prim = nla_get_u32(attr);
348			err = netkit_check_policy(default_prim, attr, extack);
349			if (err < 0)
350				return err;
351		}
352		if (data[IFLA_NETKIT_PEER_POLICY]) {
353			attr = data[IFLA_NETKIT_PEER_POLICY];
354			default_peer = nla_get_u32(attr);
355			err = netkit_check_policy(default_peer, attr, extack);
356			if (err < 0)
357				return err;
358		}
359	}
360
361	if (ifmp && tbp[IFLA_IFNAME]) {
362		nla_strscpy(ifname, tbp[IFLA_IFNAME], IFNAMSIZ);
363		ifname_assign_type = NET_NAME_USER;
364	} else {
365		strscpy(ifname, "nk%d", IFNAMSIZ);
366		ifname_assign_type = NET_NAME_ENUM;
367	}
368
369	net = rtnl_link_get_net(src_net, tbp);
370	if (IS_ERR(net))
371		return PTR_ERR(net);
372
373	peer = rtnl_create_link(net, ifname, ifname_assign_type,
374				&netkit_link_ops, tbp, extack);
375	if (IS_ERR(peer)) {
376		put_net(net);
377		return PTR_ERR(peer);
378	}
379
380	netif_inherit_tso_max(peer, dev);
381
382	if (mode == NETKIT_L2)
383		eth_hw_addr_random(peer);
384	if (ifmp && dev->ifindex)
385		peer->ifindex = ifmp->ifi_index;
386
387	nk = netkit_priv(peer);
388	nk->primary = false;
389	nk->policy = default_peer;
390	nk->mode = mode;
391	bpf_mprog_bundle_init(&nk->bundle);
392
393	err = register_netdevice(peer);
394	put_net(net);
395	if (err < 0)
396		goto err_register_peer;
397	netif_carrier_off(peer);
398	if (mode == NETKIT_L2)
399		dev_change_flags(peer, peer->flags & ~IFF_NOARP, NULL);
400
401	err = rtnl_configure_link(peer, NULL, 0, NULL);
402	if (err < 0)
403		goto err_configure_peer;
404
405	if (mode == NETKIT_L2)
406		eth_hw_addr_random(dev);
407	if (tb[IFLA_IFNAME])
408		nla_strscpy(dev->name, tb[IFLA_IFNAME], IFNAMSIZ);
409	else
410		strscpy(dev->name, "nk%d", IFNAMSIZ);
411
412	nk = netkit_priv(dev);
413	nk->primary = true;
414	nk->policy = default_prim;
415	nk->mode = mode;
416	bpf_mprog_bundle_init(&nk->bundle);
417
418	err = register_netdevice(dev);
419	if (err < 0)
420		goto err_configure_peer;
421	netif_carrier_off(dev);
422	if (mode == NETKIT_L2)
423		dev_change_flags(dev, dev->flags & ~IFF_NOARP, NULL);
424
425	rcu_assign_pointer(netkit_priv(dev)->peer, peer);
426	rcu_assign_pointer(netkit_priv(peer)->peer, dev);
427	return 0;
428err_configure_peer:
429	unregister_netdevice(peer);
430	return err;
431err_register_peer:
432	free_netdev(peer);
433	return err;
434}
435
436static struct bpf_mprog_entry *netkit_entry_fetch(struct net_device *dev,
437						  bool bundle_fallback)
438{
439	struct netkit *nk = netkit_priv(dev);
440	struct bpf_mprog_entry *entry;
441
442	ASSERT_RTNL();
443	entry = rcu_dereference_rtnl(nk->active);
444	if (entry)
445		return entry;
446	if (bundle_fallback)
447		return &nk->bundle.a;
448	return NULL;
449}
450
451static void netkit_entry_update(struct net_device *dev,
452				struct bpf_mprog_entry *entry)
453{
454	struct netkit *nk = netkit_priv(dev);
455
456	ASSERT_RTNL();
457	rcu_assign_pointer(nk->active, entry);
458}
459
/*
 * Wait for an RCU grace period, so that readers which picked up the
 * previous active entry have finished before the caller goes on.
 */
static void netkit_entry_sync(void)
{
	synchronize_rcu();
}
464
465static struct net_device *netkit_dev_fetch(struct net *net, u32 ifindex, u32 which)
466{
467	struct net_device *dev;
468	struct netkit *nk;
469
470	ASSERT_RTNL();
471
472	switch (which) {
473	case BPF_NETKIT_PRIMARY:
474	case BPF_NETKIT_PEER:
475		break;
476	default:
477		return ERR_PTR(-EINVAL);
478	}
479
480	dev = __dev_get_by_index(net, ifindex);
481	if (!dev)
482		return ERR_PTR(-ENODEV);
483	if (dev->netdev_ops != &netkit_netdev_ops)
484		return ERR_PTR(-ENXIO);
485
486	nk = netkit_priv(dev);
487	if (!nk->primary)
488		return ERR_PTR(-EACCES);
489	if (which == BPF_NETKIT_PEER) {
490		dev = rcu_dereference_rtnl(nk->peer);
491		if (!dev)
492			return ERR_PTR(-ENODEV);
493	}
494	return dev;
495}
496
497int netkit_prog_attach(const union bpf_attr *attr, struct bpf_prog *prog)
498{
499	struct bpf_mprog_entry *entry, *entry_new;
500	struct bpf_prog *replace_prog = NULL;
501	struct net_device *dev;
502	int ret;
503
504	rtnl_lock();
505	dev = netkit_dev_fetch(current->nsproxy->net_ns, attr->target_ifindex,
506			       attr->attach_type);
507	if (IS_ERR(dev)) {
508		ret = PTR_ERR(dev);
509		goto out;
510	}
511	entry = netkit_entry_fetch(dev, true);
512	if (attr->attach_flags & BPF_F_REPLACE) {
513		replace_prog = bpf_prog_get_type(attr->replace_bpf_fd,
514						 prog->type);
515		if (IS_ERR(replace_prog)) {
516			ret = PTR_ERR(replace_prog);
517			replace_prog = NULL;
518			goto out;
519		}
520	}
521	ret = bpf_mprog_attach(entry, &entry_new, prog, NULL, replace_prog,
522			       attr->attach_flags, attr->relative_fd,
523			       attr->expected_revision);
524	if (!ret) {
525		if (entry != entry_new) {
526			netkit_entry_update(dev, entry_new);
527			netkit_entry_sync();
528		}
529		bpf_mprog_commit(entry);
530	}
531out:
532	if (replace_prog)
533		bpf_prog_put(replace_prog);
534	rtnl_unlock();
535	return ret;
536}
537
538int netkit_prog_detach(const union bpf_attr *attr, struct bpf_prog *prog)
539{
540	struct bpf_mprog_entry *entry, *entry_new;
541	struct net_device *dev;
542	int ret;
543
544	rtnl_lock();
545	dev = netkit_dev_fetch(current->nsproxy->net_ns, attr->target_ifindex,
546			       attr->attach_type);
547	if (IS_ERR(dev)) {
548		ret = PTR_ERR(dev);
549		goto out;
550	}
551	entry = netkit_entry_fetch(dev, false);
552	if (!entry) {
553		ret = -ENOENT;
554		goto out;
555	}
556	ret = bpf_mprog_detach(entry, &entry_new, prog, NULL, attr->attach_flags,
557			       attr->relative_fd, attr->expected_revision);
558	if (!ret) {
559		if (!bpf_mprog_total(entry_new))
560			entry_new = NULL;
561		netkit_entry_update(dev, entry_new);
562		netkit_entry_sync();
563		bpf_mprog_commit(entry);
564	}
565out:
566	rtnl_unlock();
567	return ret;
568}
569
570int netkit_prog_query(const union bpf_attr *attr, union bpf_attr __user *uattr)
571{
572	struct net_device *dev;
573	int ret;
574
575	rtnl_lock();
576	dev = netkit_dev_fetch(current->nsproxy->net_ns,
577			       attr->query.target_ifindex,
578			       attr->query.attach_type);
579	if (IS_ERR(dev)) {
580		ret = PTR_ERR(dev);
581		goto out;
582	}
583	ret = bpf_mprog_query(attr, uattr, netkit_entry_fetch(dev, false));
584out:
585	rtnl_unlock();
586	return ret;
587}
588
589static struct netkit_link *netkit_link(const struct bpf_link *link)
590{
591	return container_of(link, struct netkit_link, link);
592}
593
594static int netkit_link_prog_attach(struct bpf_link *link, u32 flags,
595				   u32 id_or_fd, u64 revision)
596{
597	struct netkit_link *nkl = netkit_link(link);
598	struct bpf_mprog_entry *entry, *entry_new;
599	struct net_device *dev = nkl->dev;
600	int ret;
601
602	ASSERT_RTNL();
603	entry = netkit_entry_fetch(dev, true);
604	ret = bpf_mprog_attach(entry, &entry_new, link->prog, link, NULL, flags,
605			       id_or_fd, revision);
606	if (!ret) {
607		if (entry != entry_new) {
608			netkit_entry_update(dev, entry_new);
609			netkit_entry_sync();
610		}
611		bpf_mprog_commit(entry);
612	}
613	return ret;
614}
615
616static void netkit_link_release(struct bpf_link *link)
617{
618	struct netkit_link *nkl = netkit_link(link);
619	struct bpf_mprog_entry *entry, *entry_new;
620	struct net_device *dev;
621	int ret = 0;
622
623	rtnl_lock();
624	dev = nkl->dev;
625	if (!dev)
626		goto out;
627	entry = netkit_entry_fetch(dev, false);
628	if (!entry) {
629		ret = -ENOENT;
630		goto out;
631	}
632	ret = bpf_mprog_detach(entry, &entry_new, link->prog, link, 0, 0, 0);
633	if (!ret) {
634		if (!bpf_mprog_total(entry_new))
635			entry_new = NULL;
636		netkit_entry_update(dev, entry_new);
637		netkit_entry_sync();
638		bpf_mprog_commit(entry);
639		nkl->dev = NULL;
640	}
641out:
642	WARN_ON_ONCE(ret);
643	rtnl_unlock();
644}
645
646static int netkit_link_update(struct bpf_link *link, struct bpf_prog *nprog,
647			      struct bpf_prog *oprog)
648{
649	struct netkit_link *nkl = netkit_link(link);
650	struct bpf_mprog_entry *entry, *entry_new;
651	struct net_device *dev;
652	int ret = 0;
653
654	rtnl_lock();
655	dev = nkl->dev;
656	if (!dev) {
657		ret = -ENOLINK;
658		goto out;
659	}
660	if (oprog && link->prog != oprog) {
661		ret = -EPERM;
662		goto out;
663	}
664	oprog = link->prog;
665	if (oprog == nprog) {
666		bpf_prog_put(nprog);
667		goto out;
668	}
669	entry = netkit_entry_fetch(dev, false);
670	if (!entry) {
671		ret = -ENOENT;
672		goto out;
673	}
674	ret = bpf_mprog_attach(entry, &entry_new, nprog, link, oprog,
675			       BPF_F_REPLACE | BPF_F_ID,
676			       link->prog->aux->id, 0);
677	if (!ret) {
678		WARN_ON_ONCE(entry != entry_new);
679		oprog = xchg(&link->prog, nprog);
680		bpf_prog_put(oprog);
681		bpf_mprog_commit(entry);
682	}
683out:
684	rtnl_unlock();
685	return ret;
686}
687
/* bpf_link dealloc callback: free the containing netkit_link. */
static void netkit_link_dealloc(struct bpf_link *link)
{
	struct netkit_link *nkl = netkit_link(link);

	kfree(nkl);
}
692
693static void netkit_link_fdinfo(const struct bpf_link *link, struct seq_file *seq)
694{
695	const struct netkit_link *nkl = netkit_link(link);
696	u32 ifindex = 0;
697
698	rtnl_lock();
699	if (nkl->dev)
700		ifindex = nkl->dev->ifindex;
701	rtnl_unlock();
702
703	seq_printf(seq, "ifindex:\t%u\n", ifindex);
704	seq_printf(seq, "attach_type:\t%u (%s)\n",
705		   nkl->location,
706		   nkl->location == BPF_NETKIT_PRIMARY ? "primary" : "peer");
707}
708
709static int netkit_link_fill_info(const struct bpf_link *link,
710				 struct bpf_link_info *info)
711{
712	const struct netkit_link *nkl = netkit_link(link);
713	u32 ifindex = 0;
714
715	rtnl_lock();
716	if (nkl->dev)
717		ifindex = nkl->dev->ifindex;
718	rtnl_unlock();
719
720	info->netkit.ifindex = ifindex;
721	info->netkit.attach_type = nkl->location;
722	return 0;
723}
724
/*
 * bpf_link detach callback: same work as release. Always returns 0,
 * because netkit_link_release() has no return value to forward.
 */
static int netkit_link_detach(struct bpf_link *link)
{
	netkit_link_release(link);

	return 0;
}
730
731static const struct bpf_link_ops netkit_link_lops = {
732	.release	= netkit_link_release,
733	.detach		= netkit_link_detach,
734	.dealloc	= netkit_link_dealloc,
735	.update_prog	= netkit_link_update,
736	.show_fdinfo	= netkit_link_fdinfo,
737	.fill_link_info	= netkit_link_fill_info,
738};
739
740static int netkit_link_init(struct netkit_link *nkl,
741			    struct bpf_link_primer *link_primer,
742			    const union bpf_attr *attr,
743			    struct net_device *dev,
744			    struct bpf_prog *prog)
745{
746	bpf_link_init(&nkl->link, BPF_LINK_TYPE_NETKIT,
747		      &netkit_link_lops, prog);
748	nkl->location = attr->link_create.attach_type;
749	nkl->dev = dev;
750	return bpf_link_prime(&nkl->link, link_primer);
751}
752
753int netkit_link_attach(const union bpf_attr *attr, struct bpf_prog *prog)
754{
755	struct bpf_link_primer link_primer;
756	struct netkit_link *nkl;
757	struct net_device *dev;
758	int ret;
759
760	rtnl_lock();
761	dev = netkit_dev_fetch(current->nsproxy->net_ns,
762			       attr->link_create.target_ifindex,
763			       attr->link_create.attach_type);
764	if (IS_ERR(dev)) {
765		ret = PTR_ERR(dev);
766		goto out;
767	}
768	nkl = kzalloc(sizeof(*nkl), GFP_KERNEL_ACCOUNT);
769	if (!nkl) {
770		ret = -ENOMEM;
771		goto out;
772	}
773	ret = netkit_link_init(nkl, &link_primer, attr, dev, prog);
774	if (ret) {
775		kfree(nkl);
776		goto out;
777	}
778	ret = netkit_link_prog_attach(&nkl->link,
779				      attr->link_create.flags,
780				      attr->link_create.netkit.relative_fd,
781				      attr->link_create.netkit.expected_revision);
782	if (ret) {
783		nkl->dev = NULL;
784		bpf_link_cleanup(&link_primer);
785		goto out;
786	}
787	ret = bpf_link_settle(&link_primer);
788out:
789	rtnl_unlock();
790	return ret;
791}
792
793static void netkit_release_all(struct net_device *dev)
794{
795	struct bpf_mprog_entry *entry;
796	struct bpf_tuple tuple = {};
797	struct bpf_mprog_fp *fp;
798	struct bpf_mprog_cp *cp;
799
800	entry = netkit_entry_fetch(dev, false);
801	if (!entry)
802		return;
803	netkit_entry_update(dev, NULL);
804	netkit_entry_sync();
805	bpf_mprog_foreach_tuple(entry, fp, cp, tuple) {
806		if (tuple.link)
807			netkit_link(tuple.link)->dev = NULL;
808		else
809			bpf_prog_put(tuple.prog);
810	}
811}
812
/* ndo_uninit: drop all BPF attachments of the device being torn down. */
static void netkit_uninit(struct net_device *dev)
{
	netkit_release_all(dev);
}
817
818static void netkit_del_link(struct net_device *dev, struct list_head *head)
819{
820	struct netkit *nk = netkit_priv(dev);
821	struct net_device *peer = rtnl_dereference(nk->peer);
822
823	RCU_INIT_POINTER(nk->peer, NULL);
824	unregister_netdevice_queue(dev, head);
825	if (peer) {
826		nk = netkit_priv(peer);
827		RCU_INIT_POINTER(nk->peer, NULL);
828		unregister_netdevice_queue(peer, head);
829	}
830}
831
832static int netkit_change_link(struct net_device *dev, struct nlattr *tb[],
833			      struct nlattr *data[],
834			      struct netlink_ext_ack *extack)
835{
836	struct netkit *nk = netkit_priv(dev);
837	struct net_device *peer = rtnl_dereference(nk->peer);
838	enum netkit_action policy;
839	struct nlattr *attr;
840	int err;
841
842	if (!nk->primary) {
843		NL_SET_ERR_MSG(extack,
844			       "netkit link settings can be changed only through the primary device");
845		return -EACCES;
846	}
847
848	if (data[IFLA_NETKIT_MODE]) {
849		NL_SET_ERR_MSG_ATTR(extack, data[IFLA_NETKIT_MODE],
850				    "netkit link operating mode cannot be changed after device creation");
851		return -EACCES;
852	}
853
854	if (data[IFLA_NETKIT_PEER_INFO]) {
855		NL_SET_ERR_MSG_ATTR(extack, data[IFLA_NETKIT_PEER_INFO],
856				    "netkit peer info cannot be changed after device creation");
857		return -EINVAL;
858	}
859
860	if (data[IFLA_NETKIT_POLICY]) {
861		attr = data[IFLA_NETKIT_POLICY];
862		policy = nla_get_u32(attr);
863		err = netkit_check_policy(policy, attr, extack);
864		if (err)
865			return err;
866		WRITE_ONCE(nk->policy, policy);
867	}
868
869	if (data[IFLA_NETKIT_PEER_POLICY]) {
870		err = -EOPNOTSUPP;
871		attr = data[IFLA_NETKIT_PEER_POLICY];
872		policy = nla_get_u32(attr);
873		if (peer)
874			err = netkit_check_policy(policy, attr, extack);
875		if (err)
876			return err;
877		nk = netkit_priv(peer);
878		WRITE_ONCE(nk->policy, policy);
879	}
880
881	return 0;
882}
883
884static size_t netkit_get_size(const struct net_device *dev)
885{
886	return nla_total_size(sizeof(u32)) + /* IFLA_NETKIT_POLICY */
887	       nla_total_size(sizeof(u32)) + /* IFLA_NETKIT_PEER_POLICY */
888	       nla_total_size(sizeof(u8))  + /* IFLA_NETKIT_PRIMARY */
889	       nla_total_size(sizeof(u32)) + /* IFLA_NETKIT_MODE */
890	       0;
891}
892
893static int netkit_fill_info(struct sk_buff *skb, const struct net_device *dev)
894{
895	struct netkit *nk = netkit_priv(dev);
896	struct net_device *peer = rtnl_dereference(nk->peer);
897
898	if (nla_put_u8(skb, IFLA_NETKIT_PRIMARY, nk->primary))
899		return -EMSGSIZE;
900	if (nla_put_u32(skb, IFLA_NETKIT_POLICY, nk->policy))
901		return -EMSGSIZE;
902	if (nla_put_u32(skb, IFLA_NETKIT_MODE, nk->mode))
903		return -EMSGSIZE;
904
905	if (peer) {
906		nk = netkit_priv(peer);
907		if (nla_put_u32(skb, IFLA_NETKIT_PEER_POLICY, nk->policy))
908			return -EMSGSIZE;
909	}
910
911	return 0;
912}
913
914static const struct nla_policy netkit_policy[IFLA_NETKIT_MAX + 1] = {
915	[IFLA_NETKIT_PEER_INFO]		= { .len = sizeof(struct ifinfomsg) },
916	[IFLA_NETKIT_POLICY]		= { .type = NLA_U32 },
917	[IFLA_NETKIT_MODE]		= { .type = NLA_U32 },
918	[IFLA_NETKIT_PEER_POLICY]	= { .type = NLA_U32 },
919	[IFLA_NETKIT_PRIMARY]		= { .type = NLA_REJECT,
920					    .reject_message = "Primary attribute is read-only" },
921};
922
923static struct rtnl_link_ops netkit_link_ops = {
924	.kind		= DRV_NAME,
925	.priv_size	= sizeof(struct netkit),
926	.setup		= netkit_setup,
927	.newlink	= netkit_new_link,
928	.dellink	= netkit_del_link,
929	.changelink	= netkit_change_link,
930	.get_link_net	= netkit_get_link_net,
931	.get_size	= netkit_get_size,
932	.fill_info	= netkit_fill_info,
933	.policy		= netkit_policy,
934	.validate	= netkit_validate,
935	.maxtype	= IFLA_NETKIT_MAX,
936};
937
938static __init int netkit_init(void)
939{
940	BUILD_BUG_ON((int)NETKIT_NEXT != (int)TCX_NEXT ||
941		     (int)NETKIT_PASS != (int)TCX_PASS ||
942		     (int)NETKIT_DROP != (int)TCX_DROP ||
943		     (int)NETKIT_REDIRECT != (int)TCX_REDIRECT);
944
945	return rtnl_link_register(&netkit_link_ops);
946}
947
948static __exit void netkit_exit(void)
949{
950	rtnl_link_unregister(&netkit_link_ops);
951}
952
953module_init(netkit_init);
954module_exit(netkit_exit);
955
956MODULE_DESCRIPTION("BPF-programmable network device");
957MODULE_AUTHOR("Daniel Borkmann <daniel@iogearbox.net>");
958MODULE_AUTHOR("Nikolay Aleksandrov <razor@blackwall.org>");
959MODULE_LICENSE("GPL");
960MODULE_ALIAS_RTNL_LINK(DRV_NAME);
961