1/* SPDX-License-Identifier: ISC
2 *
3 * Copyright (C) 2015-2021 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
4 * Copyright (C) 2019-2021 Matt Dunwoodie <ncon@noconroy.net>
5 * Copyright (c) 2019-2020 Rubicon Communications, LLC (Netgate)
6 * Copyright (c) 2021 Kyle Evans <kevans@FreeBSD.org>
7 * Copyright (c) 2022 The FreeBSD Foundation
8 */
9
10#include "opt_inet.h"
11#include "opt_inet6.h"
12
13#include <sys/param.h>
14#include <sys/systm.h>
15#include <sys/counter.h>
16#include <sys/gtaskqueue.h>
17#include <sys/jail.h>
18#include <sys/kernel.h>
19#include <sys/lock.h>
20#include <sys/mbuf.h>
21#include <sys/module.h>
22#include <sys/nv.h>
23#include <sys/priv.h>
24#include <sys/protosw.h>
25#include <sys/rmlock.h>
26#include <sys/rwlock.h>
27#include <sys/smp.h>
28#include <sys/socket.h>
29#include <sys/socketvar.h>
30#include <sys/sockio.h>
31#include <sys/sysctl.h>
32#include <sys/sx.h>
33#include <machine/_inttypes.h>
34#include <net/bpf.h>
35#include <net/ethernet.h>
36#include <net/if.h>
37#include <net/if_clone.h>
38#include <net/if_types.h>
39#include <net/if_var.h>
40#include <net/netisr.h>
41#include <net/radix.h>
42#include <netinet/in.h>
43#include <netinet6/in6_var.h>
44#include <netinet/ip.h>
45#include <netinet/ip6.h>
46#include <netinet/ip_icmp.h>
47#include <netinet/icmp6.h>
48#include <netinet/udp_var.h>
49#include <netinet6/nd6.h>
50
51#include "wg_noise.h"
52#include "wg_cookie.h"
53#include "version.h"
54#include "if_wg.h"
55
56#define DEFAULT_MTU		(ETHERMTU - 80)
57#define MAX_MTU			(IF_MAXMTU - 80)
58
59#define MAX_STAGED_PKT		128
60#define MAX_QUEUED_PKT		1024
61#define MAX_QUEUED_PKT_MASK	(MAX_QUEUED_PKT - 1)
62
63#define MAX_QUEUED_HANDSHAKES	4096
64
#define REKEY_TIMEOUT_JITTER	334 /* ~1/3 sec in msec; upper bound passed to arc4random_uniform() */
66#define MAX_TIMER_HANDSHAKES	(90 / REKEY_TIMEOUT)
67#define NEW_HANDSHAKE_TIMEOUT	(REKEY_TIMEOUT + KEEPALIVE_TIMEOUT)
68#define UNDERLOAD_TIMEOUT	1
69
70#define DPRINTF(sc, ...) if (if_getflags(sc->sc_ifp) & IFF_DEBUG) if_printf(sc->sc_ifp, ##__VA_ARGS__)
71
/* Message type: the first 4 bytes of every packet on the wire, little-endian */
#define WG_PKT_INITIATION htole32(1)
#define WG_PKT_RESPONSE htole32(2)
#define WG_PKT_COOKIE htole32(3)
#define WG_PKT_DATA htole32(4)
77
78#define WG_PKT_PADDING		16
79#define WG_KEY_SIZE		32
80
81struct wg_pkt_initiation {
82	uint32_t		t;
83	uint32_t		s_idx;
84	uint8_t			ue[NOISE_PUBLIC_KEY_LEN];
85	uint8_t			es[NOISE_PUBLIC_KEY_LEN + NOISE_AUTHTAG_LEN];
86	uint8_t			ets[NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN];
87	struct cookie_macs	m;
88};
89
90struct wg_pkt_response {
91	uint32_t		t;
92	uint32_t		s_idx;
93	uint32_t		r_idx;
94	uint8_t			ue[NOISE_PUBLIC_KEY_LEN];
95	uint8_t			en[0 + NOISE_AUTHTAG_LEN];
96	struct cookie_macs	m;
97};
98
99struct wg_pkt_cookie {
100	uint32_t		t;
101	uint32_t		r_idx;
102	uint8_t			nonce[COOKIE_NONCE_SIZE];
103	uint8_t			ec[COOKIE_ENCRYPTED_SIZE];
104};
105
106struct wg_pkt_data {
107	uint32_t		t;
108	uint32_t		r_idx;
109	uint64_t		nonce;
110	uint8_t			buf[];
111};
112
113struct wg_endpoint {
114	union {
115		struct sockaddr		r_sa;
116		struct sockaddr_in	r_sin;
117#ifdef INET6
118		struct sockaddr_in6	r_sin6;
119#endif
120	} e_remote;
121	union {
122		struct in_addr		l_in;
123#ifdef INET6
124		struct in6_pktinfo	l_pktinfo6;
125#define l_in6 l_pktinfo6.ipi6_addr
126#endif
127	} e_local;
128};
129
130struct aip_addr {
131	uint8_t		length;
132	union {
133		uint8_t		bytes[16];
134		uint32_t	ip;
135		uint32_t	ip6[4];
136		struct in_addr	in;
137		struct in6_addr	in6;
138	};
139};
140
141struct wg_aip {
142	struct radix_node	 a_nodes[2];
143	LIST_ENTRY(wg_aip)	 a_entry;
144	struct aip_addr		 a_addr;
145	struct aip_addr		 a_mask;
146	struct wg_peer		*a_peer;
147	sa_family_t		 a_af;
148};
149
150struct wg_packet {
151	STAILQ_ENTRY(wg_packet)	 p_serial;
152	STAILQ_ENTRY(wg_packet)	 p_parallel;
153	struct wg_endpoint	 p_endpoint;
154	struct noise_keypair	*p_keypair;
155	uint64_t		 p_nonce;
156	struct mbuf		*p_mbuf;
157	int			 p_mtu;
158	sa_family_t		 p_af;
159	enum wg_ring_state {
160		WG_PACKET_UNCRYPTED,
161		WG_PACKET_CRYPTED,
162		WG_PACKET_DEAD,
163	}			 p_state;
164};
165
166STAILQ_HEAD(wg_packet_list, wg_packet);
167
168struct wg_queue {
169	struct mtx		 q_mtx;
170	struct wg_packet_list	 q_queue;
171	size_t			 q_len;
172};
173
174struct wg_peer {
175	TAILQ_ENTRY(wg_peer)		 p_entry;
176	uint64_t			 p_id;
177	struct wg_softc			*p_sc;
178
179	struct noise_remote		*p_remote;
180	struct cookie_maker		 p_cookie;
181
182	struct rwlock			 p_endpoint_lock;
183	struct wg_endpoint		 p_endpoint;
184
185	struct wg_queue	 		 p_stage_queue;
186	struct wg_queue	 		 p_encrypt_serial;
187	struct wg_queue	 		 p_decrypt_serial;
188
189	bool				 p_enabled;
190	bool				 p_need_another_keepalive;
191	uint16_t			 p_persistent_keepalive_interval;
192	struct callout			 p_new_handshake;
193	struct callout			 p_send_keepalive;
194	struct callout			 p_retry_handshake;
195	struct callout			 p_zero_key_material;
196	struct callout			 p_persistent_keepalive;
197
198	struct mtx			 p_handshake_mtx;
199	struct timespec			 p_handshake_complete;	/* nanotime */
200	int				 p_handshake_retries;
201
202	struct grouptask		 p_send;
203	struct grouptask		 p_recv;
204
205	counter_u64_t			 p_tx_bytes;
206	counter_u64_t			 p_rx_bytes;
207
208	LIST_HEAD(, wg_aip)		 p_aips;
209	size_t				 p_aips_num;
210};
211
212struct wg_socket {
213	struct socket	*so_so4;
214	struct socket	*so_so6;
215	uint32_t	 so_user_cookie;
216	int		 so_fibnum;
217	in_port_t	 so_port;
218};
219
220struct wg_softc {
221	LIST_ENTRY(wg_softc)	 sc_entry;
222	if_t			 sc_ifp;
223	int			 sc_flags;
224
225	struct ucred		*sc_ucred;
226	struct wg_socket	 sc_socket;
227
228	TAILQ_HEAD(,wg_peer)	 sc_peers;
229	size_t			 sc_peers_num;
230
231	struct noise_local	*sc_local;
232	struct cookie_checker	 sc_cookie;
233
234	struct radix_node_head	*sc_aip4;
235	struct radix_node_head	*sc_aip6;
236
237	struct grouptask	 sc_handshake;
238	struct wg_queue		 sc_handshake_queue;
239
240	struct grouptask	*sc_encrypt;
241	struct grouptask	*sc_decrypt;
242	struct wg_queue		 sc_encrypt_parallel;
243	struct wg_queue		 sc_decrypt_parallel;
244	u_int			 sc_encrypt_last_cpu;
245	u_int			 sc_decrypt_last_cpu;
246
247	struct sx		 sc_lock;
248};
249
250#define	WGF_DYING	0x0001
251
252#define MAX_LOOPS	8
253#define MTAG_WGLOOP	0x77676c70 /* wglp */
254
255#define	GROUPTASK_DRAIN(gtask)			\
256	gtaskqueue_drain((gtask)->gt_taskqueue, &(gtask)->gt_task)
257
258#define BPF_MTAP2_AF(ifp, m, af) do { \
259		uint32_t __bpf_tap_af = (af); \
260		BPF_MTAP2(ifp, &__bpf_tap_af, sizeof(__bpf_tap_af), m); \
261	} while (0)
262
263static int clone_count;
264static uma_zone_t wg_packet_zone;
265static volatile unsigned long peer_counter = 0;
266static const char wgname[] = "wg";
267static unsigned wg_osd_jail_slot;
268
269static struct sx wg_sx;
270SX_SYSINIT(wg_sx, &wg_sx, "wg_sx");
271
272static LIST_HEAD(, wg_softc) wg_list = LIST_HEAD_INITIALIZER(wg_list);
273
274static TASKQGROUP_DEFINE(wg_tqg, mp_ncpus, 1);
275
276MALLOC_DEFINE(M_WG, "WG", "wireguard");
277
278VNET_DEFINE_STATIC(struct if_clone *, wg_cloner);
279
280#define	V_wg_cloner	VNET(wg_cloner)
281#define	WG_CAPS		IFCAP_LINKSTATE
282
283struct wg_timespec64 {
284	uint64_t	tv_sec;
285	uint64_t	tv_nsec;
286};
287
288static int wg_socket_init(struct wg_softc *, in_port_t);
289static int wg_socket_bind(struct socket **, struct socket **, in_port_t *);
290static void wg_socket_set(struct wg_softc *, struct socket *, struct socket *);
291static void wg_socket_uninit(struct wg_softc *);
292static int wg_socket_set_sockopt(struct socket *, struct socket *, int, void *, size_t);
293static int wg_socket_set_cookie(struct wg_softc *, uint32_t);
294static int wg_socket_set_fibnum(struct wg_softc *, int);
295static int wg_send(struct wg_softc *, struct wg_endpoint *, struct mbuf *);
296static void wg_timers_enable(struct wg_peer *);
297static void wg_timers_disable(struct wg_peer *);
298static void wg_timers_set_persistent_keepalive(struct wg_peer *, uint16_t);
299static void wg_timers_get_last_handshake(struct wg_peer *, struct wg_timespec64 *);
300static void wg_timers_event_data_sent(struct wg_peer *);
301static void wg_timers_event_data_received(struct wg_peer *);
302static void wg_timers_event_any_authenticated_packet_sent(struct wg_peer *);
303static void wg_timers_event_any_authenticated_packet_received(struct wg_peer *);
304static void wg_timers_event_any_authenticated_packet_traversal(struct wg_peer *);
305static void wg_timers_event_handshake_initiated(struct wg_peer *);
306static void wg_timers_event_handshake_complete(struct wg_peer *);
307static void wg_timers_event_session_derived(struct wg_peer *);
308static void wg_timers_event_want_initiation(struct wg_peer *);
309static void wg_timers_run_send_initiation(struct wg_peer *, bool);
310static void wg_timers_run_retry_handshake(void *);
311static void wg_timers_run_send_keepalive(void *);
312static void wg_timers_run_new_handshake(void *);
313static void wg_timers_run_zero_key_material(void *);
314static void wg_timers_run_persistent_keepalive(void *);
315static int wg_aip_add(struct wg_softc *, struct wg_peer *, sa_family_t, const void *, uint8_t);
316static struct wg_peer *wg_aip_lookup(struct wg_softc *, sa_family_t, void *);
317static void wg_aip_remove_all(struct wg_softc *, struct wg_peer *);
318static struct wg_peer *wg_peer_alloc(struct wg_softc *, const uint8_t [WG_KEY_SIZE]);
319static void wg_peer_free_deferred(struct noise_remote *);
320static void wg_peer_destroy(struct wg_peer *);
321static void wg_peer_destroy_all(struct wg_softc *);
322static void wg_peer_send_buf(struct wg_peer *, uint8_t *, size_t);
323static void wg_send_initiation(struct wg_peer *);
324static void wg_send_response(struct wg_peer *);
325static void wg_send_cookie(struct wg_softc *, struct cookie_macs *, uint32_t, struct wg_endpoint *);
326static void wg_peer_set_endpoint(struct wg_peer *, struct wg_endpoint *);
327static void wg_peer_clear_src(struct wg_peer *);
328static void wg_peer_get_endpoint(struct wg_peer *, struct wg_endpoint *);
329static void wg_send_buf(struct wg_softc *, struct wg_endpoint *, uint8_t *, size_t);
330static void wg_send_keepalive(struct wg_peer *);
331static void wg_handshake(struct wg_softc *, struct wg_packet *);
332static void wg_encrypt(struct wg_softc *, struct wg_packet *);
333static void wg_decrypt(struct wg_softc *, struct wg_packet *);
334static void wg_softc_handshake_receive(struct wg_softc *);
335static void wg_softc_decrypt(struct wg_softc *);
336static void wg_softc_encrypt(struct wg_softc *);
337static void wg_encrypt_dispatch(struct wg_softc *);
338static void wg_decrypt_dispatch(struct wg_softc *);
339static void wg_deliver_out(struct wg_peer *);
340static void wg_deliver_in(struct wg_peer *);
341static struct wg_packet *wg_packet_alloc(struct mbuf *);
342static void wg_packet_free(struct wg_packet *);
343static void wg_queue_init(struct wg_queue *, const char *);
344static void wg_queue_deinit(struct wg_queue *);
345static size_t wg_queue_len(struct wg_queue *);
346static int wg_queue_enqueue_handshake(struct wg_queue *, struct wg_packet *);
347static struct wg_packet *wg_queue_dequeue_handshake(struct wg_queue *);
348static void wg_queue_push_staged(struct wg_queue *, struct wg_packet *);
349static void wg_queue_enlist_staged(struct wg_queue *, struct wg_packet_list *);
350static void wg_queue_delist_staged(struct wg_queue *, struct wg_packet_list *);
351static void wg_queue_purge(struct wg_queue *);
352static int wg_queue_both(struct wg_queue *, struct wg_queue *, struct wg_packet *);
353static struct wg_packet *wg_queue_dequeue_serial(struct wg_queue *);
354static struct wg_packet *wg_queue_dequeue_parallel(struct wg_queue *);
355static bool wg_input(struct mbuf *, int, struct inpcb *, const struct sockaddr *, void *);
356static void wg_peer_send_staged(struct wg_peer *);
357static int wg_clone_create(struct if_clone *ifc, char *name, size_t len,
358	struct ifc_data *ifd, if_t *ifpp);
359static void wg_qflush(if_t);
360static inline int determine_af_and_pullup(struct mbuf **m, sa_family_t *af);
361static int wg_xmit(if_t, struct mbuf *, sa_family_t, uint32_t);
362static int wg_transmit(if_t, struct mbuf *);
363static int wg_output(if_t, struct mbuf *, const struct sockaddr *, struct route *);
364static int wg_clone_destroy(struct if_clone *ifc, if_t ifp,
365	uint32_t flags);
366static bool wgc_privileged(struct wg_softc *);
367static int wgc_get(struct wg_softc *, struct wg_data_io *);
368static int wgc_set(struct wg_softc *, struct wg_data_io *);
369static int wg_up(struct wg_softc *);
370static void wg_down(struct wg_softc *);
371static void wg_reassign(if_t, struct vnet *, char *unused);
372static void wg_init(void *);
373static int wg_ioctl(if_t, u_long, caddr_t);
374static void vnet_wg_init(const void *);
375static void vnet_wg_uninit(const void *);
376static int wg_module_init(void);
377static void wg_module_deinit(void);
378
379/* TODO Peer */
380static struct wg_peer *
381wg_peer_alloc(struct wg_softc *sc, const uint8_t pub_key[WG_KEY_SIZE])
382{
383	struct wg_peer *peer;
384
385	sx_assert(&sc->sc_lock, SX_XLOCKED);
386
387	peer = malloc(sizeof(*peer), M_WG, M_WAITOK | M_ZERO);
388	peer->p_remote = noise_remote_alloc(sc->sc_local, peer, pub_key);
389	peer->p_tx_bytes = counter_u64_alloc(M_WAITOK);
390	peer->p_rx_bytes = counter_u64_alloc(M_WAITOK);
391	peer->p_id = peer_counter++;
392	peer->p_sc = sc;
393
394	cookie_maker_init(&peer->p_cookie, pub_key);
395
396	rw_init(&peer->p_endpoint_lock, "wg_peer_endpoint");
397
398	wg_queue_init(&peer->p_stage_queue, "stageq");
399	wg_queue_init(&peer->p_encrypt_serial, "txq");
400	wg_queue_init(&peer->p_decrypt_serial, "rxq");
401
402	peer->p_enabled = false;
403	peer->p_need_another_keepalive = false;
404	peer->p_persistent_keepalive_interval = 0;
405	callout_init(&peer->p_new_handshake, true);
406	callout_init(&peer->p_send_keepalive, true);
407	callout_init(&peer->p_retry_handshake, true);
408	callout_init(&peer->p_persistent_keepalive, true);
409	callout_init(&peer->p_zero_key_material, true);
410
411	mtx_init(&peer->p_handshake_mtx, "peer handshake", NULL, MTX_DEF);
412	bzero(&peer->p_handshake_complete, sizeof(peer->p_handshake_complete));
413	peer->p_handshake_retries = 0;
414
415	GROUPTASK_INIT(&peer->p_send, 0, (gtask_fn_t *)wg_deliver_out, peer);
416	taskqgroup_attach(qgroup_wg_tqg, &peer->p_send, peer, NULL, NULL, "wg send");
417	GROUPTASK_INIT(&peer->p_recv, 0, (gtask_fn_t *)wg_deliver_in, peer);
418	taskqgroup_attach(qgroup_wg_tqg, &peer->p_recv, peer, NULL, NULL, "wg recv");
419
420	LIST_INIT(&peer->p_aips);
421	peer->p_aips_num = 0;
422
423	return (peer);
424}
425
426static void
427wg_peer_free_deferred(struct noise_remote *r)
428{
429	struct wg_peer *peer = noise_remote_arg(r);
430
	/* While there are no references remaining, we may still have
	 * p_{send,recv} executing (think empty queue, but wg_deliver_{in,out}
	 * still needs to check the queue).  Wait for them to finish, then free. */
434	GROUPTASK_DRAIN(&peer->p_recv);
435	GROUPTASK_DRAIN(&peer->p_send);
436	taskqgroup_detach(qgroup_wg_tqg, &peer->p_recv);
437	taskqgroup_detach(qgroup_wg_tqg, &peer->p_send);
438
439	wg_queue_deinit(&peer->p_decrypt_serial);
440	wg_queue_deinit(&peer->p_encrypt_serial);
441	wg_queue_deinit(&peer->p_stage_queue);
442
443	counter_u64_free(peer->p_tx_bytes);
444	counter_u64_free(peer->p_rx_bytes);
445	rw_destroy(&peer->p_endpoint_lock);
446	mtx_destroy(&peer->p_handshake_mtx);
447
448	cookie_maker_free(&peer->p_cookie);
449
450	free(peer, M_WG);
451}
452
453static void
454wg_peer_destroy(struct wg_peer *peer)
455{
456	struct wg_softc *sc = peer->p_sc;
457	sx_assert(&sc->sc_lock, SX_XLOCKED);
458
	/* Disable remote and timers. This will prevent any new handshakes
	 * from occurring. */
461	noise_remote_disable(peer->p_remote);
462	wg_timers_disable(peer);
463
464	/* Now we can remove all allowed IPs so no more packets will be routed
465	 * to the peer. */
466	wg_aip_remove_all(sc, peer);
467
468	/* Remove peer from the interface, then free. Some references may still
469	 * exist to p_remote, so noise_remote_free will wait until they're all
470	 * put to call wg_peer_free_deferred. */
471	sc->sc_peers_num--;
472	TAILQ_REMOVE(&sc->sc_peers, peer, p_entry);
473	DPRINTF(sc, "Peer %" PRIu64 " destroyed\n", peer->p_id);
474	noise_remote_free(peer->p_remote, wg_peer_free_deferred);
475}
476
477static void
478wg_peer_destroy_all(struct wg_softc *sc)
479{
480	struct wg_peer *peer, *tpeer;
481	TAILQ_FOREACH_SAFE(peer, &sc->sc_peers, p_entry, tpeer)
482		wg_peer_destroy(peer);
483}
484
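/*
 * Update the peer's remote endpoint from the source of the most recently
 * authenticated packet, unless it is unchanged.
 */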
485static void
486wg_peer_set_endpoint(struct wg_peer *peer, struct wg_endpoint *e)
487{
488	MPASS(e->e_remote.r_sa.sa_family != 0);
489	if (memcmp(e, &peer->p_endpoint, sizeof(*e)) == 0)
490		return;
491
492	rw_wlock(&peer->p_endpoint_lock);
493	peer->p_endpoint = *e;
494	rw_wunlock(&peer->p_endpoint_lock);
495}
496
497static void
498wg_peer_clear_src(struct wg_peer *peer)
499{
500	rw_wlock(&peer->p_endpoint_lock);
501	bzero(&peer->p_endpoint.e_local, sizeof(peer->p_endpoint.e_local));
502	rw_wunlock(&peer->p_endpoint_lock);
503}
504
505static void
506wg_peer_get_endpoint(struct wg_peer *peer, struct wg_endpoint *e)
507{
508	rw_rlock(&peer->p_endpoint_lock);
509	*e = peer->p_endpoint;
510	rw_runlock(&peer->p_endpoint_lock);
511}
512
513/* Allowed IP */
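/*
 * Insert an allowed-IP entry (address/cidr) for a peer into the appropriate
 * address-family radix tree.  If the prefix already exists, ownership is
 * transferred to the new peer instead of adding a duplicate node.
 */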
514static int
515wg_aip_add(struct wg_softc *sc, struct wg_peer *peer, sa_family_t af, const void *addr, uint8_t cidr)
516{
517	struct radix_node_head	*root;
518	struct radix_node	*node;
519	struct wg_aip		*aip;
520	int			 ret = 0;
521
522	aip = malloc(sizeof(*aip), M_WG, M_WAITOK | M_ZERO);
523	aip->a_peer = peer;
524	aip->a_af = af;
525
526	switch (af) {
527#ifdef INET
528	case AF_INET:
529		if (cidr > 32) cidr = 32;
530		root = sc->sc_aip4;
531		aip->a_addr.in = *(const struct in_addr *)addr;
532		aip->a_mask.ip = htonl(~((1LL << (32 - cidr)) - 1) & 0xffffffff);
533		aip->a_addr.ip &= aip->a_mask.ip;
534		aip->a_addr.length = aip->a_mask.length = offsetof(struct aip_addr, in) + sizeof(struct in_addr);
535		break;
536#endif
537#ifdef INET6
538	case AF_INET6:
539		if (cidr > 128) cidr = 128;
540		root = sc->sc_aip6;
541		aip->a_addr.in6 = *(const struct in6_addr *)addr;
542		in6_prefixlen2mask(&aip->a_mask.in6, cidr);
543		for (int i = 0; i < 4; i++)
544			aip->a_addr.ip6[i] &= aip->a_mask.ip6[i];
545		aip->a_addr.length = aip->a_mask.length = offsetof(struct aip_addr, in6) + sizeof(struct in6_addr);
546		break;
547#endif
548	default:
549		free(aip, M_WG);
550		return (EAFNOSUPPORT);
551	}
552
553	RADIX_NODE_HEAD_LOCK(root);
554	node = root->rnh_addaddr(&aip->a_addr, &aip->a_mask, &root->rh, aip->a_nodes);
555	if (node == aip->a_nodes) {
556		LIST_INSERT_HEAD(&peer->p_aips, aip, a_entry);
557		peer->p_aips_num++;
558	} else if (!node)
559		node = root->rnh_lookup(&aip->a_addr, &aip->a_mask, &root->rh);
560	if (!node) {
561		free(aip, M_WG);
562		ret = ENOMEM;
563	} else if (node != aip->a_nodes) {
564		free(aip, M_WG);
565		aip = (struct wg_aip *)node;
566		if (aip->a_peer != peer) {
567			LIST_REMOVE(aip, a_entry);
568			aip->a_peer->p_aips_num--;
569			aip->a_peer = peer;
570			LIST_INSERT_HEAD(&peer->p_aips, aip, a_entry);
571			aip->a_peer->p_aips_num++;
572		}
573	}
574	RADIX_NODE_HEAD_UNLOCK(root);
575	return (ret);
576}
577
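/*
 * Look up which peer owns the allowed-IP prefix covering the given address.
 * On a match, a reference is taken on the peer's noise_remote; the caller
 * must drop it with noise_remote_put().
 */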
578static struct wg_peer *
579wg_aip_lookup(struct wg_softc *sc, sa_family_t af, void *a)
580{
581	struct radix_node_head	*root;
582	struct radix_node	*node;
583	struct wg_peer		*peer;
584	struct aip_addr		 addr;
585	RADIX_NODE_HEAD_RLOCK_TRACKER;
586
587	switch (af) {
588	case AF_INET:
589		root = sc->sc_aip4;
590		memcpy(&addr.in, a, sizeof(addr.in));
591		addr.length = offsetof(struct aip_addr, in) + sizeof(struct in_addr);
592		break;
593	case AF_INET6:
594		root = sc->sc_aip6;
595		memcpy(&addr.in6, a, sizeof(addr.in6));
596		addr.length = offsetof(struct aip_addr, in6) + sizeof(struct in6_addr);
597		break;
598	default:
599		return NULL;
600	}
601
602	RADIX_NODE_HEAD_RLOCK(root);
603	node = root->rnh_matchaddr(&addr, &root->rh);
604	if (node != NULL) {
605		peer = ((struct wg_aip *)node)->a_peer;
606		noise_remote_ref(peer->p_remote);
607	} else {
608		peer = NULL;
609	}
610	RADIX_NODE_HEAD_RUNLOCK(root);
611
612	return (peer);
613}
614
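/*
 * Remove every allowed-IP entry belonging to a peer from both the IPv4 and
 * IPv6 radix trees, e.g. when the peer is being destroyed.
 */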
615static void
616wg_aip_remove_all(struct wg_softc *sc, struct wg_peer *peer)
617{
618	struct wg_aip		*aip, *taip;
619
620	RADIX_NODE_HEAD_LOCK(sc->sc_aip4);
621	LIST_FOREACH_SAFE(aip, &peer->p_aips, a_entry, taip) {
622		if (aip->a_af == AF_INET) {
623			if (sc->sc_aip4->rnh_deladdr(&aip->a_addr, &aip->a_mask, &sc->sc_aip4->rh) == NULL)
624				panic("failed to delete aip %p", aip);
625			LIST_REMOVE(aip, a_entry);
626			peer->p_aips_num--;
627			free(aip, M_WG);
628		}
629	}
630	RADIX_NODE_HEAD_UNLOCK(sc->sc_aip4);
631
632	RADIX_NODE_HEAD_LOCK(sc->sc_aip6);
633	LIST_FOREACH_SAFE(aip, &peer->p_aips, a_entry, taip) {
634		if (aip->a_af == AF_INET6) {
635			if (sc->sc_aip6->rnh_deladdr(&aip->a_addr, &aip->a_mask, &sc->sc_aip6->rh) == NULL)
636				panic("failed to delete aip %p", aip);
637			LIST_REMOVE(aip, a_entry);
638			peer->p_aips_num--;
639			free(aip, M_WG);
640		}
641	}
642	RADIX_NODE_HEAD_UNLOCK(sc->sc_aip6);
643
644	if (!LIST_EMPTY(&peer->p_aips) || peer->p_aips_num != 0)
645		panic("wg_aip_remove_all could not delete all %p", peer);
646}
647
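/*
 * Create the tunnel's IPv4 and IPv6 UDP sockets, attach wg_input() as the
 * kernel tunneling handler, apply the configured user cookie and FIB, and
 * bind to the requested port.
 */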
648static int
649wg_socket_init(struct wg_softc *sc, in_port_t port)
650{
651	struct ucred *cred = sc->sc_ucred;
652	struct socket *so4 = NULL, *so6 = NULL;
653	int rc;
654
655	sx_assert(&sc->sc_lock, SX_XLOCKED);
656
657	if (!cred)
658		return (EBUSY);
659
660	/*
661	 * For socket creation, we use the creds of the thread that created the
662	 * tunnel rather than the current thread to maintain the semantics that
663	 * WireGuard has on Linux with network namespaces -- that the sockets
664	 * are created in their home vnet so that they can be configured and
665	 * functionally attached to a foreign vnet as the jail's only interface
666	 * to the network.
667	 */
668#ifdef INET
669	rc = socreate(AF_INET, &so4, SOCK_DGRAM, IPPROTO_UDP, cred, curthread);
670	if (rc)
671		goto out;
672
673	rc = udp_set_kernel_tunneling(so4, wg_input, NULL, sc);
674	/*
675	 * udp_set_kernel_tunneling can only fail if there is already a tunneling function set.
676	 * This should never happen with a new socket.
677	 */
678	MPASS(rc == 0);
679#endif
680
681#ifdef INET6
682	rc = socreate(AF_INET6, &so6, SOCK_DGRAM, IPPROTO_UDP, cred, curthread);
683	if (rc)
684		goto out;
685	rc = udp_set_kernel_tunneling(so6, wg_input, NULL, sc);
686	MPASS(rc == 0);
687#endif
688
689	if (sc->sc_socket.so_user_cookie) {
690		rc = wg_socket_set_sockopt(so4, so6, SO_USER_COOKIE, &sc->sc_socket.so_user_cookie, sizeof(sc->sc_socket.so_user_cookie));
691		if (rc)
692			goto out;
693	}
694	rc = wg_socket_set_sockopt(so4, so6, SO_SETFIB, &sc->sc_socket.so_fibnum, sizeof(sc->sc_socket.so_fibnum));
695	if (rc)
696		goto out;
697
698	rc = wg_socket_bind(&so4, &so6, &port);
699	if (!rc) {
700		sc->sc_socket.so_port = port;
701		wg_socket_set(sc, so4, so6);
702	}
703out:
704	if (rc) {
705		if (so4 != NULL)
706			soclose(so4);
707		if (so6 != NULL)
708			soclose(so6);
709	}
710	return (rc);
711}
712
static int
wg_socket_set_sockopt(struct socket *so4, struct socket *so6, int name,
    void *val, size_t len)
{
715	int ret4 = 0, ret6 = 0;
716	struct sockopt sopt = {
717		.sopt_dir = SOPT_SET,
718		.sopt_level = SOL_SOCKET,
719		.sopt_name = name,
720		.sopt_val = val,
721		.sopt_valsize = len
722	};
723
724	if (so4)
725		ret4 = sosetopt(so4, &sopt);
726	if (so6)
727		ret6 = sosetopt(so6, &sopt);
728	return (ret4 ?: ret6);
729}
730
static int
wg_socket_set_cookie(struct wg_softc *sc, uint32_t user_cookie)
{
733	struct wg_socket *so = &sc->sc_socket;
734	int ret;
735
736	sx_assert(&sc->sc_lock, SX_XLOCKED);
737	ret = wg_socket_set_sockopt(so->so_so4, so->so_so6, SO_USER_COOKIE, &user_cookie, sizeof(user_cookie));
738	if (!ret)
739		so->so_user_cookie = user_cookie;
740	return (ret);
741}
742
static int
wg_socket_set_fibnum(struct wg_softc *sc, int fibnum)
{
745	struct wg_socket *so = &sc->sc_socket;
746	int ret;
747
748	sx_assert(&sc->sc_lock, SX_XLOCKED);
749
750	ret = wg_socket_set_sockopt(so->so_so4, so->so_so6, SO_SETFIB, &fibnum, sizeof(fibnum));
751	if (!ret)
752		so->so_fibnum = fibnum;
753	return (ret);
754}
755
756static void
757wg_socket_uninit(struct wg_softc *sc)
758{
759	wg_socket_set(sc, NULL, NULL);
760}
761
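/*
 * Publish a new pair of sockets.  The old pointers are swapped out first and
 * an epoch wait guarantees no concurrent wg_send() still holds them before
 * they are closed.
 */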
762static void
763wg_socket_set(struct wg_softc *sc, struct socket *new_so4, struct socket *new_so6)
764{
765	struct wg_socket *so = &sc->sc_socket;
766	struct socket *so4, *so6;
767
768	sx_assert(&sc->sc_lock, SX_XLOCKED);
769
770	so4 = atomic_load_ptr(&so->so_so4);
771	so6 = atomic_load_ptr(&so->so_so6);
772	atomic_store_ptr(&so->so_so4, new_so4);
773	atomic_store_ptr(&so->so_so6, new_so6);
774
775	if (!so4 && !so6)
776		return;
777	NET_EPOCH_WAIT();
778	if (so4)
779		soclose(so4);
780	if (so6)
781		soclose(so6);
782}
783
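/*
 * Bind both sockets to the requested port.  If port 0 was requested, the port
 * picked for the IPv4 socket is reused for IPv6 so both listen on the same
 * number.  An EADDRNOTAVAIL on only one address family is tolerated: that
 * socket is closed and the other is kept.
 */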
784static int
785wg_socket_bind(struct socket **in_so4, struct socket **in_so6, in_port_t *requested_port)
786{
787	struct socket *so4 = *in_so4, *so6 = *in_so6;
788	int ret4 = 0, ret6 = 0;
789	in_port_t port = *requested_port;
790	struct sockaddr_in sin = {
791		.sin_len = sizeof(struct sockaddr_in),
792		.sin_family = AF_INET,
793		.sin_port = htons(port)
794	};
795	struct sockaddr_in6 sin6 = {
796		.sin6_len = sizeof(struct sockaddr_in6),
797		.sin6_family = AF_INET6,
798		.sin6_port = htons(port)
799	};
800
801	if (so4) {
802		ret4 = sobind(so4, (struct sockaddr *)&sin, curthread);
803		if (ret4 && ret4 != EADDRNOTAVAIL)
804			return (ret4);
805		if (!ret4 && !sin.sin_port) {
806			struct sockaddr_in bound_sin =
807			    { .sin_len = sizeof(bound_sin) };
808			int ret;
809
810			ret = sosockaddr(so4, (struct sockaddr *)&bound_sin);
811			if (ret)
812				return (ret);
813			port = ntohs(bound_sin.sin_port);
814			sin6.sin6_port = bound_sin.sin_port;
815		}
816	}
817
818	if (so6) {
819		ret6 = sobind(so6, (struct sockaddr *)&sin6, curthread);
820		if (ret6 && ret6 != EADDRNOTAVAIL)
821			return (ret6);
822		if (!ret6 && !sin6.sin6_port) {
823			struct sockaddr_in6 bound_sin6 =
824			    { .sin6_len = sizeof(bound_sin6) };
825			int ret;
826
827			ret = sosockaddr(so6, (struct sockaddr *)&bound_sin6);
828			if (ret)
829				return (ret);
830			port = ntohs(bound_sin6.sin6_port);
831		}
832	}
833
834	if (ret4 && ret6)
835		return (ret4);
836	*requested_port = port;
837	if (ret4 && !ret6 && so4) {
838		soclose(so4);
839		*in_so4 = NULL;
840	} else if (ret6 && !ret4 && so6) {
841		soclose(so6);
842		*in_so6 = NULL;
843	}
844	return (0);
845}
846
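/*
 * Send an encapsulated mbuf to the given endpoint over whichever UDP socket
 * matches its address family.  If a local address is cached for the endpoint,
 * it is passed as IP_SENDSRCADDR/IPV6_PKTINFO control data so the packet goes
 * out from that source address.
 */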
847static int
848wg_send(struct wg_softc *sc, struct wg_endpoint *e, struct mbuf *m)
849{
850	struct epoch_tracker et;
851	struct sockaddr *sa;
852	struct wg_socket *so = &sc->sc_socket;
853	struct socket *so4, *so6;
854	struct mbuf *control = NULL;
855	int ret = 0;
856	size_t len = m->m_pkthdr.len;
857
858	/* Get local control address before locking */
859	if (e->e_remote.r_sa.sa_family == AF_INET) {
860		if (e->e_local.l_in.s_addr != INADDR_ANY)
861			control = sbcreatecontrol((caddr_t)&e->e_local.l_in,
862			    sizeof(struct in_addr), IP_SENDSRCADDR,
863			    IPPROTO_IP, M_NOWAIT);
864#ifdef INET6
865	} else if (e->e_remote.r_sa.sa_family == AF_INET6) {
866		if (!IN6_IS_ADDR_UNSPECIFIED(&e->e_local.l_in6))
867			control = sbcreatecontrol((caddr_t)&e->e_local.l_pktinfo6,
868			    sizeof(struct in6_pktinfo), IPV6_PKTINFO,
869			    IPPROTO_IPV6, M_NOWAIT);
870#endif
871	} else {
872		m_freem(m);
873		return (EAFNOSUPPORT);
874	}
875
876	/* Get remote address */
877	sa = &e->e_remote.r_sa;
878
879	NET_EPOCH_ENTER(et);
880	so4 = atomic_load_ptr(&so->so_so4);
881	so6 = atomic_load_ptr(&so->so_so6);
882	if (e->e_remote.r_sa.sa_family == AF_INET && so4 != NULL)
883		ret = sosend(so4, sa, NULL, m, control, 0, curthread);
884	else if (e->e_remote.r_sa.sa_family == AF_INET6 && so6 != NULL)
885		ret = sosend(so6, sa, NULL, m, control, 0, curthread);
886	else {
887		ret = ENOTCONN;
888		m_freem(control);
889		m_freem(m);
890	}
891	NET_EPOCH_EXIT(et);
892	if (ret == 0) {
893		if_inc_counter(sc->sc_ifp, IFCOUNTER_OPACKETS, 1);
894		if_inc_counter(sc->sc_ifp, IFCOUNTER_OBYTES, len);
895	}
896	return (ret);
897}
898
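/*
 * Copy a handshake message into a fresh mbuf and send it.  If the cached
 * local address is no longer usable (EADDRNOTAVAIL), clear it and retry once,
 * letting the stack pick a source address.
 */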
899static void
900wg_send_buf(struct wg_softc *sc, struct wg_endpoint *e, uint8_t *buf, size_t len)
901{
902	struct mbuf	*m;
903	int		 ret = 0;
904	bool		 retried = false;
905
906retry:
907	m = m_get2(len, M_NOWAIT, MT_DATA, M_PKTHDR);
908	if (!m) {
909		ret = ENOMEM;
910		goto out;
911	}
912	m_copyback(m, 0, len, buf);
913
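	/*
	 * On the first pass ret is 0, so we may retry once with the local
	 * address cleared; after a retry ret is non-zero and we send exactly
	 * once more without retrying again.
	 */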
914	if (ret == 0) {
915		ret = wg_send(sc, e, m);
916		/* Retry if we couldn't bind to e->e_local */
917		if (ret == EADDRNOTAVAIL && !retried) {
918			bzero(&e->e_local, sizeof(e->e_local));
919			retried = true;
920			goto retry;
921		}
922	} else {
923		ret = wg_send(sc, e, m);
924	}
925out:
926	if (ret)
927		DPRINTF(sc, "Unable to send packet: %d\n", ret);
928}
929
930/* Timers */
931static void
932wg_timers_enable(struct wg_peer *peer)
933{
934	atomic_store_bool(&peer->p_enabled, true);
935	wg_timers_run_persistent_keepalive(peer);
936}
937
938static void
939wg_timers_disable(struct wg_peer *peer)
940{
941	/* By setting p_enabled = false, then calling NET_EPOCH_WAIT, we can be
942	 * sure no new handshakes are created after the wait. This is because
943	 * all callout_resets (scheduling the callout) are guarded by
944	 * p_enabled. We can be sure all sections that read p_enabled and then
945	 * optionally call callout_reset are finished as they are surrounded by
946	 * NET_EPOCH_{ENTER,EXIT}.
947	 *
948	 * However, as new callouts may be scheduled during NET_EPOCH_WAIT (but
949	 * not after), we stop all callouts leaving no callouts active.
950	 *
951	 * We should also pull NET_EPOCH_WAIT out of the FOREACH(peer) loops, but the
952	 * performance impact is acceptable for the time being. */
953	atomic_store_bool(&peer->p_enabled, false);
954	NET_EPOCH_WAIT();
955	atomic_store_bool(&peer->p_need_another_keepalive, false);
956
957	callout_stop(&peer->p_new_handshake);
958	callout_stop(&peer->p_send_keepalive);
959	callout_stop(&peer->p_retry_handshake);
960	callout_stop(&peer->p_persistent_keepalive);
961	callout_stop(&peer->p_zero_key_material);
962}
963
964static void
965wg_timers_set_persistent_keepalive(struct wg_peer *peer, uint16_t interval)
966{
967	struct epoch_tracker et;
968	if (interval != peer->p_persistent_keepalive_interval) {
969		atomic_store_16(&peer->p_persistent_keepalive_interval, interval);
970		NET_EPOCH_ENTER(et);
971		if (atomic_load_bool(&peer->p_enabled))
972			wg_timers_run_persistent_keepalive(peer);
973		NET_EPOCH_EXIT(et);
974	}
975}
976
977static void
978wg_timers_get_last_handshake(struct wg_peer *peer, struct wg_timespec64 *time)
979{
980	mtx_lock(&peer->p_handshake_mtx);
981	time->tv_sec = peer->p_handshake_complete.tv_sec;
982	time->tv_nsec = peer->p_handshake_complete.tv_nsec;
983	mtx_unlock(&peer->p_handshake_mtx);
984}
985
986static void
987wg_timers_event_data_sent(struct wg_peer *peer)
988{
989	struct epoch_tracker et;
990	NET_EPOCH_ENTER(et);
991	if (atomic_load_bool(&peer->p_enabled) &&
992	    !callout_pending(&peer->p_new_handshake))
993		callout_reset(&peer->p_new_handshake, MSEC_2_TICKS(
994		    NEW_HANDSHAKE_TIMEOUT * 1000 +
995		    arc4random_uniform(REKEY_TIMEOUT_JITTER)),
996		    wg_timers_run_new_handshake, peer);
997	NET_EPOCH_EXIT(et);
998}
999
1000static void
1001wg_timers_event_data_received(struct wg_peer *peer)
1002{
1003	struct epoch_tracker et;
1004	NET_EPOCH_ENTER(et);
1005	if (atomic_load_bool(&peer->p_enabled)) {
1006		if (!callout_pending(&peer->p_send_keepalive))
1007			callout_reset(&peer->p_send_keepalive,
1008			    MSEC_2_TICKS(KEEPALIVE_TIMEOUT * 1000),
1009			    wg_timers_run_send_keepalive, peer);
1010		else
1011			atomic_store_bool(&peer->p_need_another_keepalive,
1012			    true);
1013	}
1014	NET_EPOCH_EXIT(et);
1015}
1016
1017static void
1018wg_timers_event_any_authenticated_packet_sent(struct wg_peer *peer)
1019{
1020	callout_stop(&peer->p_send_keepalive);
1021}
1022
1023static void
1024wg_timers_event_any_authenticated_packet_received(struct wg_peer *peer)
1025{
1026	callout_stop(&peer->p_new_handshake);
1027}
1028
1029static void
1030wg_timers_event_any_authenticated_packet_traversal(struct wg_peer *peer)
1031{
1032	struct epoch_tracker et;
1033	uint16_t interval;
1034	NET_EPOCH_ENTER(et);
1035	interval = atomic_load_16(&peer->p_persistent_keepalive_interval);
1036	if (atomic_load_bool(&peer->p_enabled) && interval > 0)
1037		callout_reset(&peer->p_persistent_keepalive,
1038		     MSEC_2_TICKS(interval * 1000),
1039		     wg_timers_run_persistent_keepalive, peer);
1040	NET_EPOCH_EXIT(et);
1041}
1042
1043static void
1044wg_timers_event_handshake_initiated(struct wg_peer *peer)
1045{
1046	struct epoch_tracker et;
1047	NET_EPOCH_ENTER(et);
1048	if (atomic_load_bool(&peer->p_enabled))
1049		callout_reset(&peer->p_retry_handshake, MSEC_2_TICKS(
1050		    REKEY_TIMEOUT * 1000 +
1051		    arc4random_uniform(REKEY_TIMEOUT_JITTER)),
1052		    wg_timers_run_retry_handshake, peer);
1053	NET_EPOCH_EXIT(et);
1054}
1055
1056static void
1057wg_timers_event_handshake_complete(struct wg_peer *peer)
1058{
1059	struct epoch_tracker et;
1060	NET_EPOCH_ENTER(et);
1061	if (atomic_load_bool(&peer->p_enabled)) {
1062		mtx_lock(&peer->p_handshake_mtx);
1063		callout_stop(&peer->p_retry_handshake);
1064		peer->p_handshake_retries = 0;
1065		getnanotime(&peer->p_handshake_complete);
1066		mtx_unlock(&peer->p_handshake_mtx);
1067		wg_timers_run_send_keepalive(peer);
1068	}
1069	NET_EPOCH_EXIT(et);
1070}
1071
1072static void
1073wg_timers_event_session_derived(struct wg_peer *peer)
1074{
1075	struct epoch_tracker et;
1076	NET_EPOCH_ENTER(et);
1077	if (atomic_load_bool(&peer->p_enabled))
1078		callout_reset(&peer->p_zero_key_material,
1079		    MSEC_2_TICKS(REJECT_AFTER_TIME * 3 * 1000),
1080		    wg_timers_run_zero_key_material, peer);
1081	NET_EPOCH_EXIT(et);
1082}
1083
1084static void
1085wg_timers_event_want_initiation(struct wg_peer *peer)
1086{
1087	struct epoch_tracker et;
1088	NET_EPOCH_ENTER(et);
1089	if (atomic_load_bool(&peer->p_enabled))
1090		wg_timers_run_send_initiation(peer, false);
1091	NET_EPOCH_EXIT(et);
1092}
1093
1094static void
1095wg_timers_run_send_initiation(struct wg_peer *peer, bool is_retry)
1096{
1097	if (!is_retry)
1098		peer->p_handshake_retries = 0;
1099	if (noise_remote_initiation_expired(peer->p_remote) == ETIMEDOUT)
1100		wg_send_initiation(peer);
1101}
1102
1103static void
1104wg_timers_run_retry_handshake(void *_peer)
1105{
1106	struct epoch_tracker et;
1107	struct wg_peer *peer = _peer;
1108
1109	mtx_lock(&peer->p_handshake_mtx);
1110	if (peer->p_handshake_retries <= MAX_TIMER_HANDSHAKES) {
1111		peer->p_handshake_retries++;
1112		mtx_unlock(&peer->p_handshake_mtx);
1113
1114		DPRINTF(peer->p_sc, "Handshake for peer %" PRIu64 " did not complete "
1115		    "after %d seconds, retrying (try %d)\n", peer->p_id,
1116		    REKEY_TIMEOUT, peer->p_handshake_retries + 1);
1117		wg_peer_clear_src(peer);
1118		wg_timers_run_send_initiation(peer, true);
1119	} else {
1120		mtx_unlock(&peer->p_handshake_mtx);
1121
1122		DPRINTF(peer->p_sc, "Handshake for peer %" PRIu64 " did not complete "
1123		    "after %d retries, giving up\n", peer->p_id,
1124		    MAX_TIMER_HANDSHAKES + 2);
1125
1126		callout_stop(&peer->p_send_keepalive);
1127		wg_queue_purge(&peer->p_stage_queue);
1128		NET_EPOCH_ENTER(et);
1129		if (atomic_load_bool(&peer->p_enabled) &&
1130		    !callout_pending(&peer->p_zero_key_material))
1131			callout_reset(&peer->p_zero_key_material,
1132			    MSEC_2_TICKS(REJECT_AFTER_TIME * 3 * 1000),
1133			    wg_timers_run_zero_key_material, peer);
1134		NET_EPOCH_EXIT(et);
1135	}
1136}
1137
1138static void
1139wg_timers_run_send_keepalive(void *_peer)
1140{
1141	struct epoch_tracker et;
1142	struct wg_peer *peer = _peer;
1143
1144	wg_send_keepalive(peer);
1145	NET_EPOCH_ENTER(et);
1146	if (atomic_load_bool(&peer->p_enabled) &&
1147	    atomic_load_bool(&peer->p_need_another_keepalive)) {
1148		atomic_store_bool(&peer->p_need_another_keepalive, false);
1149		callout_reset(&peer->p_send_keepalive,
1150		    MSEC_2_TICKS(KEEPALIVE_TIMEOUT * 1000),
1151		    wg_timers_run_send_keepalive, peer);
1152	}
1153	NET_EPOCH_EXIT(et);
1154}
1155
1156static void
1157wg_timers_run_new_handshake(void *_peer)
1158{
1159	struct wg_peer *peer = _peer;
1160
1161	DPRINTF(peer->p_sc, "Retrying handshake with peer %" PRIu64 " because we "
1162	    "stopped hearing back after %d seconds\n",
1163	    peer->p_id, NEW_HANDSHAKE_TIMEOUT);
1164
1165	wg_peer_clear_src(peer);
1166	wg_timers_run_send_initiation(peer, false);
1167}
1168
1169static void
1170wg_timers_run_zero_key_material(void *_peer)
1171{
1172	struct wg_peer *peer = _peer;
1173
1174	DPRINTF(peer->p_sc, "Zeroing out keys for peer %" PRIu64 ", since we "
1175	    "haven't received a new one in %d seconds\n",
1176	    peer->p_id, REJECT_AFTER_TIME * 3);
1177	noise_remote_keypairs_clear(peer->p_remote);
1178}
1179
1180static void
1181wg_timers_run_persistent_keepalive(void *_peer)
1182{
1183	struct wg_peer *peer = _peer;
1184
1185	if (atomic_load_16(&peer->p_persistent_keepalive_interval) > 0)
1186		wg_send_keepalive(peer);
1187}
1188
1189/* TODO Handshake */
1190static void
1191wg_peer_send_buf(struct wg_peer *peer, uint8_t *buf, size_t len)
1192{
1193	struct wg_endpoint endpoint;
1194
1195	counter_u64_add(peer->p_tx_bytes, len);
1196	wg_timers_event_any_authenticated_packet_traversal(peer);
1197	wg_timers_event_any_authenticated_packet_sent(peer);
1198	wg_peer_get_endpoint(peer, &endpoint);
1199	wg_send_buf(peer->p_sc, &endpoint, buf, len);
1200}
1201
1202static void
1203wg_send_initiation(struct wg_peer *peer)
1204{
1205	struct wg_pkt_initiation pkt;
1206
1207	if (noise_create_initiation(peer->p_remote, &pkt.s_idx, pkt.ue,
1208	    pkt.es, pkt.ets) != 0)
1209		return;
1210
1211	DPRINTF(peer->p_sc, "Sending handshake initiation to peer %" PRIu64 "\n", peer->p_id);
1212
1213	pkt.t = WG_PKT_INITIATION;
1214	cookie_maker_mac(&peer->p_cookie, &pkt.m, &pkt,
1215	    sizeof(pkt) - sizeof(pkt.m));
1216	wg_peer_send_buf(peer, (uint8_t *)&pkt, sizeof(pkt));
1217	wg_timers_event_handshake_initiated(peer);
1218}
1219
1220static void
1221wg_send_response(struct wg_peer *peer)
1222{
1223	struct wg_pkt_response pkt;
1224
1225	if (noise_create_response(peer->p_remote, &pkt.s_idx, &pkt.r_idx,
1226	    pkt.ue, pkt.en) != 0)
1227		return;
1228
1229	DPRINTF(peer->p_sc, "Sending handshake response to peer %" PRIu64 "\n", peer->p_id);
1230
1231	wg_timers_event_session_derived(peer);
1232	pkt.t = WG_PKT_RESPONSE;
1233	cookie_maker_mac(&peer->p_cookie, &pkt.m, &pkt,
1234	     sizeof(pkt)-sizeof(pkt.m));
1235	wg_peer_send_buf(peer, (uint8_t*)&pkt, sizeof(pkt));
1236}
1237
1238static void
1239wg_send_cookie(struct wg_softc *sc, struct cookie_macs *cm, uint32_t idx,
1240    struct wg_endpoint *e)
1241{
1242	struct wg_pkt_cookie	pkt;
1243
1244	DPRINTF(sc, "Sending cookie response for denied handshake message\n");
1245
1246	pkt.t = WG_PKT_COOKIE;
1247	pkt.r_idx = idx;
1248
1249	cookie_checker_create_payload(&sc->sc_cookie, cm, pkt.nonce,
1250	    pkt.ec, &e->e_remote.r_sa);
1251	wg_send_buf(sc, e, (uint8_t *)&pkt, sizeof(pkt));
1252}
1253
1254static void
1255wg_send_keepalive(struct wg_peer *peer)
1256{
1257	struct wg_packet *pkt;
1258	struct mbuf *m;
1259
1260	if (wg_queue_len(&peer->p_stage_queue) > 0)
1261		goto send;
1262	if ((m = m_gethdr(M_NOWAIT, MT_DATA)) == NULL)
1263		return;
1264	if ((pkt = wg_packet_alloc(m)) == NULL) {
1265		m_freem(m);
1266		return;
1267	}
1268	wg_queue_push_staged(&peer->p_stage_queue, pkt);
1269	DPRINTF(peer->p_sc, "Sending keepalive packet to peer %" PRIu64 "\n", peer->p_id);
1270send:
1271	wg_peer_send_staged(peer);
1272}
1273
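/*
 * Consume one packet from the handshake queue: validate its cookie MACs
 * (sending a cookie reply when under load), then hand it to the noise layer
 * as an initiation, response or cookie message.
 */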
1274static void
1275wg_handshake(struct wg_softc *sc, struct wg_packet *pkt)
1276{
1277	struct wg_pkt_initiation	*init;
1278	struct wg_pkt_response		*resp;
1279	struct wg_pkt_cookie		*cook;
1280	struct wg_endpoint		*e;
1281	struct wg_peer			*peer;
1282	struct mbuf			*m;
1283	struct noise_remote		*remote = NULL;
1284	int				 res;
1285	bool				 underload = false;
1286	static sbintime_t		 wg_last_underload; /* sbinuptime */
1287
1288	underload = wg_queue_len(&sc->sc_handshake_queue) >= MAX_QUEUED_HANDSHAKES / 8;
1289	if (underload) {
1290		wg_last_underload = getsbinuptime();
1291	} else if (wg_last_underload) {
1292		underload = wg_last_underload + UNDERLOAD_TIMEOUT * SBT_1S > getsbinuptime();
1293		if (!underload)
1294			wg_last_underload = 0;
1295	}
1296
1297	m = pkt->p_mbuf;
1298	e = &pkt->p_endpoint;
1299
1300	if ((pkt->p_mbuf = m = m_pullup(m, m->m_pkthdr.len)) == NULL)
1301		goto error;
1302
1303	switch (*mtod(m, uint32_t *)) {
1304	case WG_PKT_INITIATION:
1305		init = mtod(m, struct wg_pkt_initiation *);
1306
1307		res = cookie_checker_validate_macs(&sc->sc_cookie, &init->m,
1308				init, sizeof(*init) - sizeof(init->m),
1309				underload, &e->e_remote.r_sa,
1310				if_getvnet(sc->sc_ifp));
1311
1312		if (res == EINVAL) {
1313			DPRINTF(sc, "Invalid initiation MAC\n");
1314			goto error;
1315		} else if (res == ECONNREFUSED) {
1316			DPRINTF(sc, "Handshake ratelimited\n");
1317			goto error;
1318		} else if (res == EAGAIN) {
1319			wg_send_cookie(sc, &init->m, init->s_idx, e);
1320			goto error;
1321		} else if (res != 0) {
1322			panic("unexpected response: %d\n", res);
1323		}
1324
1325		if (noise_consume_initiation(sc->sc_local, &remote,
1326		    init->s_idx, init->ue, init->es, init->ets) != 0) {
1327			DPRINTF(sc, "Invalid handshake initiation\n");
1328			goto error;
1329		}
1330
1331		peer = noise_remote_arg(remote);
1332
1333		DPRINTF(sc, "Receiving handshake initiation from peer %" PRIu64 "\n", peer->p_id);
1334
1335		wg_peer_set_endpoint(peer, e);
1336		wg_send_response(peer);
1337		break;
1338	case WG_PKT_RESPONSE:
1339		resp = mtod(m, struct wg_pkt_response *);
1340
1341		res = cookie_checker_validate_macs(&sc->sc_cookie, &resp->m,
1342				resp, sizeof(*resp) - sizeof(resp->m),
1343				underload, &e->e_remote.r_sa,
1344				if_getvnet(sc->sc_ifp));
1345
1346		if (res == EINVAL) {
1347			DPRINTF(sc, "Invalid response MAC\n");
1348			goto error;
1349		} else if (res == ECONNREFUSED) {
1350			DPRINTF(sc, "Handshake ratelimited\n");
1351			goto error;
1352		} else if (res == EAGAIN) {
1353			wg_send_cookie(sc, &resp->m, resp->s_idx, e);
1354			goto error;
1355		} else if (res != 0) {
1356			panic("unexpected response: %d\n", res);
1357		}
1358
1359		if (noise_consume_response(sc->sc_local, &remote,
1360		    resp->s_idx, resp->r_idx, resp->ue, resp->en) != 0) {
1361			DPRINTF(sc, "Invalid handshake response\n");
1362			goto error;
1363		}
1364
1365		peer = noise_remote_arg(remote);
1366		DPRINTF(sc, "Receiving handshake response from peer %" PRIu64 "\n", peer->p_id);
1367
1368		wg_peer_set_endpoint(peer, e);
1369		wg_timers_event_session_derived(peer);
1370		wg_timers_event_handshake_complete(peer);
1371		break;
1372	case WG_PKT_COOKIE:
1373		cook = mtod(m, struct wg_pkt_cookie *);
1374
1375		if ((remote = noise_remote_index(sc->sc_local, cook->r_idx)) == NULL) {
1376			DPRINTF(sc, "Unknown cookie index\n");
1377			goto error;
1378		}
1379
1380		peer = noise_remote_arg(remote);
1381
1382		if (cookie_maker_consume_payload(&peer->p_cookie,
1383		    cook->nonce, cook->ec) == 0) {
1384			DPRINTF(sc, "Receiving cookie response\n");
1385		} else {
1386			DPRINTF(sc, "Could not decrypt cookie response\n");
1387			goto error;
1388		}
1389
1390		goto not_authenticated;
1391	default:
1392		panic("invalid packet in handshake queue");
1393	}
1394
1395	wg_timers_event_any_authenticated_packet_received(peer);
1396	wg_timers_event_any_authenticated_packet_traversal(peer);
1397
1398not_authenticated:
1399	counter_u64_add(peer->p_rx_bytes, m->m_pkthdr.len);
1400	if_inc_counter(sc->sc_ifp, IFCOUNTER_IPACKETS, 1);
1401	if_inc_counter(sc->sc_ifp, IFCOUNTER_IBYTES, m->m_pkthdr.len);
1402error:
1403	if (remote != NULL)
1404		noise_remote_put(remote);
1405	wg_packet_free(pkt);
1406}
1407
1408static void
1409wg_softc_handshake_receive(struct wg_softc *sc)
1410{
1411	struct wg_packet *pkt;
1412	while ((pkt = wg_queue_dequeue_handshake(&sc->sc_handshake_queue)) != NULL)
1413		wg_handshake(sc, pkt);
1414}
1415
1416static void
1417wg_mbuf_reset(struct mbuf *m)
1418{
1419
1420	struct m_tag *t, *tmp;
1421
	/*
	 * We want to reset the mbuf to a newly allocated state, containing
	 * just the packet contents. Unfortunately FreeBSD doesn't seem to
	 * offer this anywhere, so we have to make it up as we go. If we can
	 * get this into kern/kern_mbuf.c, that would be best.
	 *
	 * Notice: this may break things unexpectedly, but it is better to fail
	 *         closed in the extreme case than to leak information in every
	 *         case.
	 *
	 * With that said, all this attempts to do is remove any extraneous
	 * information that could be present.
	 */
1435
1436	M_ASSERTPKTHDR(m);
1437
1438	m->m_flags &= ~(M_BCAST|M_MCAST|M_VLANTAG|M_PROMISC|M_PROTOFLAGS);
1439
1440	M_HASHTYPE_CLEAR(m);
1441#ifdef NUMA
1442        m->m_pkthdr.numa_domain = M_NODOM;
1443#endif
1444	SLIST_FOREACH_SAFE(t, &m->m_pkthdr.tags, m_tag_link, tmp) {
1445		if ((t->m_tag_id != 0 || t->m_tag_cookie != MTAG_WGLOOP) &&
1446		    t->m_tag_id != PACKET_TAG_MACLABEL)
1447			m_tag_delete(m, t);
1448	}
1449
1450	KASSERT((m->m_pkthdr.csum_flags & CSUM_SND_TAG) == 0,
1451	    ("%s: mbuf %p has a send tag", __func__, m));
1452
1453	m->m_pkthdr.csum_flags = 0;
1454	m->m_pkthdr.PH_per.sixtyfour[0] = 0;
1455	m->m_pkthdr.PH_loc.sixtyfour[0] = 0;
1456}
1457
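/*
 * Pad the plaintext up to the next multiple of WG_PKT_PADDING (16), but never
 * beyond the MTU.  For example, a 29-byte packet gets 3 bytes of padding,
 * while a packet already at the MTU gets none.
 */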
1458static inline unsigned int
1459calculate_padding(struct wg_packet *pkt)
1460{
1461	unsigned int padded_size, last_unit = pkt->p_mbuf->m_pkthdr.len;
1462
1463	/* Keepalive packets don't set p_mtu, but also have a length of zero. */
1464	if (__predict_false(pkt->p_mtu == 0)) {
1465		padded_size = (last_unit + (WG_PKT_PADDING - 1)) &
1466		    ~(WG_PKT_PADDING - 1);
1467		return (padded_size - last_unit);
1468	}
1469
1470	if (__predict_false(last_unit > pkt->p_mtu))
1471		last_unit %= pkt->p_mtu;
1472
1473	padded_size = (last_unit + (WG_PKT_PADDING - 1)) & ~(WG_PKT_PADDING - 1);
1474	if (pkt->p_mtu < padded_size)
1475		padded_size = pkt->p_mtu;
1476	return (padded_size - last_unit);
1477}
1478
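/*
 * Encrypt one packet for transmission: pad it out, seal it with the keypair,
 * prepend the wg_pkt_data header, and mark the packet CRYPTED (or DEAD on
 * failure) before poking the peer's serial send task.
 */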
1479static void
1480wg_encrypt(struct wg_softc *sc, struct wg_packet *pkt)
1481{
1482	static const uint8_t	 padding[WG_PKT_PADDING] = { 0 };
1483	struct wg_pkt_data	*data;
1484	struct wg_peer		*peer;
1485	struct noise_remote	*remote;
1486	struct mbuf		*m;
1487	uint32_t		 idx;
1488	unsigned int		 padlen;
1489	enum wg_ring_state	 state = WG_PACKET_DEAD;
1490
1491	remote = noise_keypair_remote(pkt->p_keypair);
1492	peer = noise_remote_arg(remote);
1493	m = pkt->p_mbuf;
1494
1495	/* Pad the packet */
1496	padlen = calculate_padding(pkt);
1497	if (padlen != 0 && !m_append(m, padlen, padding))
1498		goto out;
1499
1500	/* Do encryption */
1501	if (noise_keypair_encrypt(pkt->p_keypair, &idx, pkt->p_nonce, m) != 0)
1502		goto out;
1503
1504	/* Put header into packet */
1505	M_PREPEND(m, sizeof(struct wg_pkt_data), M_NOWAIT);
1506	if (m == NULL)
1507		goto out;
1508	data = mtod(m, struct wg_pkt_data *);
1509	data->t = WG_PKT_DATA;
1510	data->r_idx = idx;
1511	data->nonce = htole64(pkt->p_nonce);
1512
1513	wg_mbuf_reset(m);
1514	state = WG_PACKET_CRYPTED;
1515out:
1516	pkt->p_mbuf = m;
1517	atomic_store_rel_int(&pkt->p_state, state);
1518	GROUPTASK_ENQUEUE(&peer->p_send);
1519	noise_remote_put(remote);
1520}
1521
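/*
 * Decrypt one received packet: strip the wg_pkt_data header, open the
 * payload, and check that the inner source address belongs to this peer's
 * allowed IPs before marking the packet CRYPTED for in-order delivery.
 */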
1522static void
1523wg_decrypt(struct wg_softc *sc, struct wg_packet *pkt)
1524{
1525	struct wg_peer		*peer, *allowed_peer;
1526	struct noise_remote	*remote;
1527	struct mbuf		*m;
1528	int			 len;
1529	enum wg_ring_state	 state = WG_PACKET_DEAD;
1530
1531	remote = noise_keypair_remote(pkt->p_keypair);
1532	peer = noise_remote_arg(remote);
1533	m = pkt->p_mbuf;
1534
1535	/* Read nonce and then adjust to remove the header. */
1536	pkt->p_nonce = le64toh(mtod(m, struct wg_pkt_data *)->nonce);
1537	m_adj(m, sizeof(struct wg_pkt_data));
1538
1539	if (noise_keypair_decrypt(pkt->p_keypair, pkt->p_nonce, m) != 0)
1540		goto out;
1541
1542	/* A packet with length 0 is a keepalive packet */
1543	if (__predict_false(m->m_pkthdr.len == 0)) {
1544		DPRINTF(sc, "Receiving keepalive packet from peer "
1545		    "%" PRIu64 "\n", peer->p_id);
1546		state = WG_PACKET_CRYPTED;
1547		goto out;
1548	}
1549
	/*
	 * We can let the network stack handle the intricate validation of the
	 * IP header; we only check the size and the version so that we can
	 * read the source address in wg_aip_lookup.
	 */
1555
1556	if (determine_af_and_pullup(&m, &pkt->p_af) == 0) {
1557		if (pkt->p_af == AF_INET) {
1558			struct ip *ip = mtod(m, struct ip *);
1559			allowed_peer = wg_aip_lookup(sc, AF_INET, &ip->ip_src);
1560			len = ntohs(ip->ip_len);
1561			if (len >= sizeof(struct ip) && len < m->m_pkthdr.len)
1562				m_adj(m, len - m->m_pkthdr.len);
1563		} else if (pkt->p_af == AF_INET6) {
1564			struct ip6_hdr *ip6 = mtod(m, struct ip6_hdr *);
1565			allowed_peer = wg_aip_lookup(sc, AF_INET6, &ip6->ip6_src);
1566			len = ntohs(ip6->ip6_plen) + sizeof(struct ip6_hdr);
1567			if (len < m->m_pkthdr.len)
1568				m_adj(m, len - m->m_pkthdr.len);
1569		} else
1570			panic("determine_af_and_pullup returned unexpected value");
1571	} else {
1572		DPRINTF(sc, "Packet is neither ipv4 nor ipv6 from peer %" PRIu64 "\n", peer->p_id);
1573		goto out;
1574	}
1575
1576	/* We only want to compare the address, not dereference, so drop the ref. */
1577	if (allowed_peer != NULL)
1578		noise_remote_put(allowed_peer->p_remote);
1579
1580	if (__predict_false(peer != allowed_peer)) {
1581		DPRINTF(sc, "Packet has unallowed src IP from peer %" PRIu64 "\n", peer->p_id);
1582		goto out;
1583	}
1584
1585	wg_mbuf_reset(m);
1586	state = WG_PACKET_CRYPTED;
1587out:
1588	pkt->p_mbuf = m;
1589	atomic_store_rel_int(&pkt->p_state, state);
1590	GROUPTASK_ENQUEUE(&peer->p_recv);
1591	noise_remote_put(remote);
1592}
1593
1594static void
1595wg_softc_decrypt(struct wg_softc *sc)
1596{
1597	struct wg_packet *pkt;
1598
1599	while ((pkt = wg_queue_dequeue_parallel(&sc->sc_decrypt_parallel)) != NULL)
1600		wg_decrypt(sc, pkt);
1601}
1602
1603static void
1604wg_softc_encrypt(struct wg_softc *sc)
1605{
1606	struct wg_packet *pkt;
1607
1608	while ((pkt = wg_queue_dequeue_parallel(&sc->sc_encrypt_parallel)) != NULL)
1609		wg_encrypt(sc, pkt);
1610}
1611
1612static void
1613wg_encrypt_dispatch(struct wg_softc *sc)
1614{
	/*
	 * The update to encrypt_last_cpu is racy in that we may
	 * reschedule the task for the same CPU multiple times, but
	 * the race doesn't really matter.
	 */
1620	u_int cpu = (sc->sc_encrypt_last_cpu + 1) % mp_ncpus;
1621	sc->sc_encrypt_last_cpu = cpu;
1622	GROUPTASK_ENQUEUE(&sc->sc_encrypt[cpu]);
1623}
1624
1625static void
1626wg_decrypt_dispatch(struct wg_softc *sc)
1627{
1628	u_int cpu = (sc->sc_decrypt_last_cpu + 1) % mp_ncpus;
1629	sc->sc_decrypt_last_cpu = cpu;
1630	GROUPTASK_ENQUEUE(&sc->sc_decrypt[cpu]);
1631}
1632
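/*
 * Per-peer serial transmit task: drain the encrypt-serial queue in order and
 * send each completed packet to the peer's current endpoint.
 */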
1633static void
1634wg_deliver_out(struct wg_peer *peer)
1635{
1636	struct wg_endpoint	 endpoint;
1637	struct wg_softc		*sc = peer->p_sc;
1638	struct wg_packet	*pkt;
1639	struct mbuf		*m;
1640	int			 rc, len;
1641
1642	wg_peer_get_endpoint(peer, &endpoint);
1643
1644	while ((pkt = wg_queue_dequeue_serial(&peer->p_encrypt_serial)) != NULL) {
1645		if (atomic_load_acq_int(&pkt->p_state) != WG_PACKET_CRYPTED)
1646			goto error;
1647
1648		m = pkt->p_mbuf;
1649		pkt->p_mbuf = NULL;
1650
1651		len = m->m_pkthdr.len;
1652
1653		wg_timers_event_any_authenticated_packet_traversal(peer);
1654		wg_timers_event_any_authenticated_packet_sent(peer);
1655		rc = wg_send(sc, &endpoint, m);
1656		if (rc == 0) {
1657			if (len > (sizeof(struct wg_pkt_data) + NOISE_AUTHTAG_LEN))
1658				wg_timers_event_data_sent(peer);
1659			counter_u64_add(peer->p_tx_bytes, len);
1660		} else if (rc == EADDRNOTAVAIL) {
1661			wg_peer_clear_src(peer);
1662			wg_peer_get_endpoint(peer, &endpoint);
1663			goto error;
1664		} else {
1665			goto error;
1666		}
1667		wg_packet_free(pkt);
1668		if (noise_keep_key_fresh_send(peer->p_remote))
1669			wg_timers_event_want_initiation(peer);
1670		continue;
1671error:
1672		if_inc_counter(sc->sc_ifp, IFCOUNTER_OERRORS, 1);
1673		wg_packet_free(pkt);
1674	}
1675}
1676
1677#ifdef DEV_NETMAP
1678/*
1679 * Hand a packet to the netmap RX ring, via netmap's
1680 * freebsd_generic_rx_handler().
1681 */
1682static void
1683wg_deliver_netmap(if_t ifp, struct mbuf *m, int af)
1684{
1685	struct ether_header *eh;
1686
1687	M_PREPEND(m, ETHER_HDR_LEN, M_NOWAIT);
1688	if (__predict_false(m == NULL)) {
1689		if_inc_counter(ifp, IFCOUNTER_IQDROPS, 1);
1690		return;
1691	}
1692
1693	eh = mtod(m, struct ether_header *);
1694	eh->ether_type = af == AF_INET ?
1695	    htons(ETHERTYPE_IP) : htons(ETHERTYPE_IPV6);
1696	memcpy(eh->ether_shost, "\x02\x02\x02\x02\x02\x02", ETHER_ADDR_LEN);
1697	memcpy(eh->ether_dhost, "\xff\xff\xff\xff\xff\xff", ETHER_ADDR_LEN);
1698	if_input(ifp, m);
1699}
1700#endif
1701
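/*
 * Per-peer serial receive task: drain the decrypt-serial queue in order, run
 * the anti-replay nonce check, update timers and counters, and pass the inner
 * packet to the network stack (or to netmap when enabled).
 */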
1702static void
1703wg_deliver_in(struct wg_peer *peer)
1704{
1705	struct wg_softc		*sc = peer->p_sc;
1706	if_t			 ifp = sc->sc_ifp;
1707	struct wg_packet	*pkt;
1708	struct mbuf		*m;
1709	struct epoch_tracker	 et;
1710	int			 af;
1711
1712	while ((pkt = wg_queue_dequeue_serial(&peer->p_decrypt_serial)) != NULL) {
1713		if (atomic_load_acq_int(&pkt->p_state) != WG_PACKET_CRYPTED)
1714			goto error;
1715
1716		m = pkt->p_mbuf;
1717		if (noise_keypair_nonce_check(pkt->p_keypair, pkt->p_nonce) != 0)
1718			goto error;
1719
1720		if (noise_keypair_received_with(pkt->p_keypair) == ECONNRESET)
1721			wg_timers_event_handshake_complete(peer);
1722
1723		wg_timers_event_any_authenticated_packet_received(peer);
1724		wg_timers_event_any_authenticated_packet_traversal(peer);
1725		wg_peer_set_endpoint(peer, &pkt->p_endpoint);
1726
1727		counter_u64_add(peer->p_rx_bytes, m->m_pkthdr.len +
1728		    sizeof(struct wg_pkt_data) + NOISE_AUTHTAG_LEN);
1729		if_inc_counter(sc->sc_ifp, IFCOUNTER_IPACKETS, 1);
1730		if_inc_counter(sc->sc_ifp, IFCOUNTER_IBYTES, m->m_pkthdr.len +
1731		    sizeof(struct wg_pkt_data) + NOISE_AUTHTAG_LEN);
1732
1733		if (m->m_pkthdr.len == 0)
1734			goto done;
1735
1736		af = pkt->p_af;
1737		MPASS(af == AF_INET || af == AF_INET6);
1738		pkt->p_mbuf = NULL;
1739
1740		m->m_pkthdr.rcvif = ifp;
1741
1742		NET_EPOCH_ENTER(et);
1743		BPF_MTAP2_AF(ifp, m, af);
1744
1745		CURVNET_SET(if_getvnet(ifp));
1746		M_SETFIB(m, if_getfib(ifp));
1747#ifdef DEV_NETMAP
1748		if ((if_getcapenable(ifp) & IFCAP_NETMAP) != 0)
1749			wg_deliver_netmap(ifp, m, af);
1750		else
1751#endif
1752		if (af == AF_INET)
1753			netisr_dispatch(NETISR_IP, m);
1754		else if (af == AF_INET6)
1755			netisr_dispatch(NETISR_IPV6, m);
1756		CURVNET_RESTORE();
1757		NET_EPOCH_EXIT(et);
1758
1759		wg_timers_event_data_received(peer);
1760
1761done:
1762		if (noise_keep_key_fresh_recv(peer->p_remote))
1763			wg_timers_event_want_initiation(peer);
1764		wg_packet_free(pkt);
1765		continue;
1766error:
1767		if_inc_counter(ifp, IFCOUNTER_IERRORS, 1);
1768		wg_packet_free(pkt);
1769	}
1770}
1771
1772static struct wg_packet *
1773wg_packet_alloc(struct mbuf *m)
1774{
1775	struct wg_packet *pkt;
1776
1777	if ((pkt = uma_zalloc(wg_packet_zone, M_NOWAIT | M_ZERO)) == NULL)
1778		return (NULL);
1779	pkt->p_mbuf = m;
1780	return (pkt);
1781}
1782
1783static void
1784wg_packet_free(struct wg_packet *pkt)
1785{
1786	if (pkt->p_keypair != NULL)
1787		noise_keypair_put(pkt->p_keypair);
1788	if (pkt->p_mbuf != NULL)
1789		m_freem(pkt->p_mbuf);
1790	uma_zfree(wg_packet_zone, pkt);
1791}
1792
1793static void
1794wg_queue_init(struct wg_queue *queue, const char *name)
1795{
1796	mtx_init(&queue->q_mtx, name, NULL, MTX_DEF);
1797	STAILQ_INIT(&queue->q_queue);
1798	queue->q_len = 0;
1799}
1800
1801static void
1802wg_queue_deinit(struct wg_queue *queue)
1803{
1804	wg_queue_purge(queue);
1805	mtx_destroy(&queue->q_mtx);
1806}
1807
1808static size_t
1809wg_queue_len(struct wg_queue *queue)
1810{
1811	return (queue->q_len);
1812}
1813
1814static int
1815wg_queue_enqueue_handshake(struct wg_queue *hs, struct wg_packet *pkt)
1816{
1817	int ret = 0;
1818	mtx_lock(&hs->q_mtx);
1819	if (hs->q_len < MAX_QUEUED_HANDSHAKES) {
1820		STAILQ_INSERT_TAIL(&hs->q_queue, pkt, p_parallel);
1821		hs->q_len++;
1822	} else {
1823		ret = ENOBUFS;
1824	}
1825	mtx_unlock(&hs->q_mtx);
1826	if (ret != 0)
1827		wg_packet_free(pkt);
1828	return (ret);
1829}
1830
1831static struct wg_packet *
1832wg_queue_dequeue_handshake(struct wg_queue *hs)
1833{
1834	struct wg_packet *pkt;
1835	mtx_lock(&hs->q_mtx);
1836	if ((pkt = STAILQ_FIRST(&hs->q_queue)) != NULL) {
1837		STAILQ_REMOVE_HEAD(&hs->q_queue, p_parallel);
1838		hs->q_len--;
1839	}
1840	mtx_unlock(&hs->q_mtx);
1841	return (pkt);
1842}
1843
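/*
 * Append a packet to the peer's staging queue.  The queue is bounded at
 * MAX_STAGED_PKT entries; when full, the oldest staged packet is dropped
 * to make room.
 */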
1844static void
1845wg_queue_push_staged(struct wg_queue *staged, struct wg_packet *pkt)
1846{
1847	struct wg_packet *old = NULL;
1848
1849	mtx_lock(&staged->q_mtx);
1850	if (staged->q_len >= MAX_STAGED_PKT) {
1851		old = STAILQ_FIRST(&staged->q_queue);
1852		STAILQ_REMOVE_HEAD(&staged->q_queue, p_parallel);
1853		staged->q_len--;
1854	}
1855	STAILQ_INSERT_TAIL(&staged->q_queue, pkt, p_parallel);
1856	staged->q_len++;
1857	mtx_unlock(&staged->q_mtx);
1858
1859	if (old != NULL)
1860		wg_packet_free(old);
1861}
1862
1863static void
1864wg_queue_enlist_staged(struct wg_queue *staged, struct wg_packet_list *list)
1865{
1866	struct wg_packet *pkt, *tpkt;
1867	STAILQ_FOREACH_SAFE(pkt, list, p_parallel, tpkt)
1868		wg_queue_push_staged(staged, pkt);
1869}
1870
1871static void
1872wg_queue_delist_staged(struct wg_queue *staged, struct wg_packet_list *list)
1873{
1874	STAILQ_INIT(list);
1875	mtx_lock(&staged->q_mtx);
1876	STAILQ_CONCAT(list, &staged->q_queue);
1877	staged->q_len = 0;
1878	mtx_unlock(&staged->q_mtx);
1879}
1880
1881static void
1882wg_queue_purge(struct wg_queue *staged)
1883{
1884	struct wg_packet_list list;
1885	struct wg_packet *pkt, *tpkt;
1886	wg_queue_delist_staged(staged, &list);
1887	STAILQ_FOREACH_SAFE(pkt, &list, p_parallel, tpkt)
1888		wg_packet_free(pkt);
1889}
1890
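/*
 * Enqueue a packet on both the per-peer serial queue (which preserves
 * delivery order) and the shared parallel queue (which feeds the per-CPU
 * crypto tasks).  If the serial queue is full the packet is freed; if only
 * the parallel queue is full it is marked dead and reaped later from the
 * serial queue.  Either case returns ENOBUFS.
 */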
1891static int
1892wg_queue_both(struct wg_queue *parallel, struct wg_queue *serial, struct wg_packet *pkt)
1893{
1894	pkt->p_state = WG_PACKET_UNCRYPTED;
1895
1896	mtx_lock(&serial->q_mtx);
1897	if (serial->q_len < MAX_QUEUED_PKT) {
1898		serial->q_len++;
1899		STAILQ_INSERT_TAIL(&serial->q_queue, pkt, p_serial);
1900	} else {
1901		mtx_unlock(&serial->q_mtx);
1902		wg_packet_free(pkt);
1903		return (ENOBUFS);
1904	}
1905	mtx_unlock(&serial->q_mtx);
1906
1907	mtx_lock(&parallel->q_mtx);
1908	if (parallel->q_len < MAX_QUEUED_PKT) {
1909		parallel->q_len++;
1910		STAILQ_INSERT_TAIL(&parallel->q_queue, pkt, p_parallel);
1911	} else {
1912		mtx_unlock(&parallel->q_mtx);
1913		pkt->p_state = WG_PACKET_DEAD;
1914		return (ENOBUFS);
1915	}
1916	mtx_unlock(&parallel->q_mtx);
1917
1918	return (0);
1919}
1920
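/*
 * Dequeue from a serial queue only once the head packet has left the
 * UNCRYPTED state, i.e. the crypto workers are done with it, so that
 * per-peer ordering is preserved.
 */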
1921static struct wg_packet *
1922wg_queue_dequeue_serial(struct wg_queue *serial)
1923{
1924	struct wg_packet *pkt = NULL;
1925	mtx_lock(&serial->q_mtx);
1926	if (serial->q_len > 0 && STAILQ_FIRST(&serial->q_queue)->p_state != WG_PACKET_UNCRYPTED) {
1927		serial->q_len--;
1928		pkt = STAILQ_FIRST(&serial->q_queue);
1929		STAILQ_REMOVE_HEAD(&serial->q_queue, p_serial);
1930	}
1931	mtx_unlock(&serial->q_mtx);
1932	return (pkt);
1933}
1934
1935static struct wg_packet *
1936wg_queue_dequeue_parallel(struct wg_queue *parallel)
1937{
1938	struct wg_packet *pkt = NULL;
1939	mtx_lock(&parallel->q_mtx);
1940	if (parallel->q_len > 0) {
1941		parallel->q_len--;
1942		pkt = STAILQ_FIRST(&parallel->q_queue);
1943		STAILQ_REMOVE_HEAD(&parallel->q_queue, p_parallel);
1944	}
1945	mtx_unlock(&parallel->q_mtx);
1946	return (pkt);
1947}
1948
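/*
 * Input path for packets arriving on the WireGuard UDP socket: strip the
 * UDP header, record the remote/local addresses, then either queue the
 * packet for the handshake task or, for data packets, look up the keypair
 * by receiver index and queue it for parallel decryption.
 */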
1949static bool
1950wg_input(struct mbuf *m, int offset, struct inpcb *inpcb,
1951    const struct sockaddr *sa, void *_sc)
1952{
1953#ifdef INET
1954	const struct sockaddr_in	*sin;
1955#endif
1956#ifdef INET6
1957	const struct sockaddr_in6	*sin6;
1958#endif
1959	struct noise_remote		*remote;
1960	struct wg_pkt_data		*data;
1961	struct wg_packet		*pkt;
1962	struct wg_peer			*peer;
1963	struct wg_softc			*sc = _sc;
1964	struct mbuf			*defragged;
1965
1966	defragged = m_defrag(m, M_NOWAIT);
1967	if (defragged)
1968		m = defragged;
1969	m = m_unshare(m, M_NOWAIT);
1970	if (!m) {
1971		if_inc_counter(sc->sc_ifp, IFCOUNTER_IQDROPS, 1);
1972		return true;
1973	}
1974
1975	/* Caller provided us with `sa`, no need for this header. */
1976	m_adj(m, offset + sizeof(struct udphdr));
1977
1978	/* Pullup enough to read packet type */
1979	if ((m = m_pullup(m, sizeof(uint32_t))) == NULL) {
1980		if_inc_counter(sc->sc_ifp, IFCOUNTER_IQDROPS, 1);
1981		return true;
1982	}
1983
1984	if ((pkt = wg_packet_alloc(m)) == NULL) {
1985		if_inc_counter(sc->sc_ifp, IFCOUNTER_IQDROPS, 1);
1986		m_freem(m);
1987		return true;
1988	}
1989
1990	/* Save send/recv address and port for later. */
1991	switch (sa->sa_family) {
1992#ifdef INET
1993	case AF_INET:
1994		sin = (const struct sockaddr_in *)sa;
1995		pkt->p_endpoint.e_remote.r_sin = sin[0];
1996		pkt->p_endpoint.e_local.l_in = sin[1].sin_addr;
1997		break;
1998#endif
1999#ifdef INET6
2000	case AF_INET6:
2001		sin6 = (const struct sockaddr_in6 *)sa;
2002		pkt->p_endpoint.e_remote.r_sin6 = sin6[0];
2003		pkt->p_endpoint.e_local.l_in6 = sin6[1].sin6_addr;
2004		break;
2005#endif
2006	default:
2007		goto error;
2008	}
2009
2010	if ((m->m_pkthdr.len == sizeof(struct wg_pkt_initiation) &&
2011		*mtod(m, uint32_t *) == WG_PKT_INITIATION) ||
2012	    (m->m_pkthdr.len == sizeof(struct wg_pkt_response) &&
2013		*mtod(m, uint32_t *) == WG_PKT_RESPONSE) ||
2014	    (m->m_pkthdr.len == sizeof(struct wg_pkt_cookie) &&
2015		*mtod(m, uint32_t *) == WG_PKT_COOKIE)) {
2016
2017		if (wg_queue_enqueue_handshake(&sc->sc_handshake_queue, pkt) != 0) {
2018			if_inc_counter(sc->sc_ifp, IFCOUNTER_IQDROPS, 1);
2019			DPRINTF(sc, "Dropping handshake packet\n");
2020		}
2021		GROUPTASK_ENQUEUE(&sc->sc_handshake);
2022	} else if (m->m_pkthdr.len >= sizeof(struct wg_pkt_data) +
2023	    NOISE_AUTHTAG_LEN && *mtod(m, uint32_t *) == WG_PKT_DATA) {
2024
2025		/* Pullup whole header to read r_idx below. */
2026		if ((pkt->p_mbuf = m_pullup(m, sizeof(struct wg_pkt_data))) == NULL)
2027			goto error;
2028
2029		data = mtod(pkt->p_mbuf, struct wg_pkt_data *);
2030		if ((pkt->p_keypair = noise_keypair_lookup(sc->sc_local, data->r_idx)) == NULL)
2031			goto error;
2032
2033		remote = noise_keypair_remote(pkt->p_keypair);
2034		peer = noise_remote_arg(remote);
2035		if (wg_queue_both(&sc->sc_decrypt_parallel, &peer->p_decrypt_serial, pkt) != 0)
2036			if_inc_counter(sc->sc_ifp, IFCOUNTER_IQDROPS, 1);
2037		wg_decrypt_dispatch(sc);
2038		noise_remote_put(remote);
2039	} else {
2040		goto error;
2041	}
2042	return true;
2043error:
2044	if_inc_counter(sc->sc_ifp, IFCOUNTER_IERRORS, 1);
2045	wg_packet_free(pkt);
2046	return true;
2047}
2048
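/*
 * Flush a peer's staging queue: take a reference to the current keypair,
 * assign each staged packet a nonce, and move the packets to the
 * encryption queues.  If no valid keypair is available (or its nonce space
 * is exhausted), the packets are restaged and a new handshake is requested.
 */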
2049static void
2050wg_peer_send_staged(struct wg_peer *peer)
2051{
2052	struct wg_packet_list	 list;
2053	struct noise_keypair	*keypair;
2054	struct wg_packet	*pkt, *tpkt;
2055	struct wg_softc		*sc = peer->p_sc;
2056
2057	wg_queue_delist_staged(&peer->p_stage_queue, &list);
2058
2059	if (STAILQ_EMPTY(&list))
2060		return;
2061
2062	if ((keypair = noise_keypair_current(peer->p_remote)) == NULL)
2063		goto error;
2064
2065	STAILQ_FOREACH(pkt, &list, p_parallel) {
2066		if (noise_keypair_nonce_next(keypair, &pkt->p_nonce) != 0)
2067			goto error_keypair;
2068	}
2069	STAILQ_FOREACH_SAFE(pkt, &list, p_parallel, tpkt) {
2070		pkt->p_keypair = noise_keypair_ref(keypair);
2071		if (wg_queue_both(&sc->sc_encrypt_parallel, &peer->p_encrypt_serial, pkt) != 0)
2072			if_inc_counter(sc->sc_ifp, IFCOUNTER_OQDROPS, 1);
2073	}
2074	wg_encrypt_dispatch(sc);
2075	noise_keypair_put(keypair);
2076	return;
2077
2078error_keypair:
2079	noise_keypair_put(keypair);
2080error:
2081	wg_queue_enlist_staged(&peer->p_stage_queue, &list);
2082	wg_timers_event_want_initiation(peer);
2083}
2084
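/*
 * Transmit-error helper: count an output error and, for IP packets, send
 * an ICMP/ICMPv6 unreachable back to the sender (which consumes the mbuf)
 * before freeing whatever remains of the packet.
 */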
2085static inline void
2086xmit_err(if_t ifp, struct mbuf *m, struct wg_packet *pkt, sa_family_t af)
2087{
2088	if_inc_counter(ifp, IFCOUNTER_OERRORS, 1);
2089	switch (af) {
2090#ifdef INET
2091	case AF_INET:
2092		icmp_error(m, ICMP_UNREACH, ICMP_UNREACH_HOST, 0, 0);
2093		if (pkt)
2094			pkt->p_mbuf = NULL;
2095		m = NULL;
2096		break;
2097#endif
2098#ifdef INET6
2099	case AF_INET6:
2100		icmp6_error(m, ICMP6_DST_UNREACH, 0, 0);
2101		if (pkt)
2102			pkt->p_mbuf = NULL;
2103		m = NULL;
2104		break;
2105#endif
2106	}
2107	if (pkt)
2108		wg_packet_free(pkt);
2109	else if (m)
2110		m_freem(m);
2111}
2112
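/*
 * Common transmit path: wrap the mbuf in a wg_packet, find the peer owning
 * the destination address in the allowed-IPs table, and stage the packet
 * for that peer.  Fails if no peer matches, the peer has no usable
 * endpoint, or the packet is looping through the tunnel.
 */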
2113static int
2114wg_xmit(if_t ifp, struct mbuf *m, sa_family_t af, uint32_t mtu)
2115{
2116	struct wg_packet	*pkt = NULL;
2117	struct wg_softc		*sc = if_getsoftc(ifp);
2118	struct wg_peer		*peer;
2119	int			 rc = 0;
2120	sa_family_t		 peer_af;
2121
2122	/* Work around lifetime issue in the ipv6 mld code. */
2123	if (__predict_false((if_getflags(ifp) & IFF_DYING) || !sc)) {
2124		rc = ENXIO;
2125		goto err_xmit;
2126	}
2127
2128	if ((pkt = wg_packet_alloc(m)) == NULL) {
2129		rc = ENOBUFS;
2130		goto err_xmit;
2131	}
2132	pkt->p_mtu = mtu;
2133	pkt->p_af = af;
2134
2135	if (af == AF_INET) {
2136		peer = wg_aip_lookup(sc, AF_INET, &mtod(m, struct ip *)->ip_dst);
2137	} else if (af == AF_INET6) {
2138		peer = wg_aip_lookup(sc, AF_INET6, &mtod(m, struct ip6_hdr *)->ip6_dst);
2139	} else {
2140		rc = EAFNOSUPPORT;
2141		goto err_xmit;
2142	}
2143
2144	BPF_MTAP2_AF(ifp, m, pkt->p_af);
2145
2146	if (__predict_false(peer == NULL)) {
2147		rc = ENETUNREACH;
2148		goto err_xmit;
2149	}
2150
2151	if (__predict_false(if_tunnel_check_nesting(ifp, m, MTAG_WGLOOP, MAX_LOOPS))) {
		DPRINTF(sc, "Packet looped\n");
2153		rc = ELOOP;
2154		goto err_peer;
2155	}
2156
2157	peer_af = peer->p_endpoint.e_remote.r_sa.sa_family;
2158	if (__predict_false(peer_af != AF_INET && peer_af != AF_INET6)) {
2159		DPRINTF(sc, "No valid endpoint has been configured or "
2160			    "discovered for peer %" PRIu64 "\n", peer->p_id);
2161		rc = EHOSTUNREACH;
2162		goto err_peer;
2163	}
2164
2165	wg_queue_push_staged(&peer->p_stage_queue, pkt);
2166	wg_peer_send_staged(peer);
2167	noise_remote_put(peer->p_remote);
2168	return (0);
2169
2170err_peer:
2171	noise_remote_put(peer->p_remote);
2172err_xmit:
2173	xmit_err(ifp, m, pkt, af);
2174	return (rc);
2175}
2176
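/*
 * Classify an outbound mbuf as IPv4 or IPv6 from the IP version field,
 * pulling up enough of the header to make the check safe.
 */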
2177static inline int
2178determine_af_and_pullup(struct mbuf **m, sa_family_t *af)
2179{
2180	u_char ipv;
2181	if ((*m)->m_pkthdr.len >= sizeof(struct ip6_hdr))
2182		*m = m_pullup(*m, sizeof(struct ip6_hdr));
2183	else if ((*m)->m_pkthdr.len >= sizeof(struct ip))
2184		*m = m_pullup(*m, sizeof(struct ip));
2185	else
2186		return (EAFNOSUPPORT);
2187	if (*m == NULL)
2188		return (ENOBUFS);
2189	ipv = mtod(*m, struct ip *)->ip_v;
2190	if (ipv == 4)
2191		*af = AF_INET;
2192	else if (ipv == 6 && (*m)->m_pkthdr.len >= sizeof(struct ip6_hdr))
2193		*af = AF_INET6;
2194	else
2195		return (EAFNOSUPPORT);
2196	return (0);
2197}
2198
2199#ifdef DEV_NETMAP
2200static int
2201determine_ethertype_and_pullup(struct mbuf **m, int *etp)
2202{
2203	struct ether_header *eh;
2204
2205	*m = m_pullup(*m, sizeof(struct ether_header));
2206	if (__predict_false(*m == NULL))
2207		return (ENOBUFS);
2208	eh = mtod(*m, struct ether_header *);
2209	*etp = ntohs(eh->ether_type);
2210	if (*etp != ETHERTYPE_IP && *etp != ETHERTYPE_IPV6)
2211		return (EAFNOSUPPORT);
2212	return (0);
2213}
2214
2215/*
2216 * This should only be invoked by netmap, via nm_os_generic_xmit_frame(), to
2217 * transmit packets from the netmap TX ring.
2218 */
2219static int
2220wg_transmit(if_t ifp, struct mbuf *m)
2221{
2222	sa_family_t af;
2223	int et, ret;
2224	struct mbuf *defragged;
2225
2226	KASSERT((if_getcapenable(ifp) & IFCAP_NETMAP) != 0,
2227	    ("%s: ifp %p is not in netmap mode", __func__, ifp));
2228
2229	defragged = m_defrag(m, M_NOWAIT);
2230	if (defragged)
2231		m = defragged;
2232	m = m_unshare(m, M_NOWAIT);
2233	if (!m) {
2234		xmit_err(ifp, m, NULL, AF_UNSPEC);
2235		return (ENOBUFS);
2236	}
2237
2238	ret = determine_ethertype_and_pullup(&m, &et);
2239	if (ret) {
2240		xmit_err(ifp, m, NULL, AF_UNSPEC);
2241		return (ret);
2242	}
2243	m_adj(m, sizeof(struct ether_header));
2244
2245	ret = determine_af_and_pullup(&m, &af);
2246	if (ret) {
2247		xmit_err(ifp, m, NULL, AF_UNSPEC);
2248		return (ret);
2249	}
2250
2251	/*
2252	 * netmap only gets to see transient errors, since it handles errors by
2253	 * refusing to advance the transmit ring and retrying later.
2254	 */
2255	ret = wg_xmit(ifp, m, af, if_getmtu(ifp));
2256	if (ret == ENOBUFS)
2257		return (ret);
2258	return (0);
2259}
2260
2261/*
2262 * This should only be invoked by netmap, via nm_os_send_up(), to process
2263 * packets from the host TX ring.
2264 */
2265static void
2266wg_if_input(if_t ifp, struct mbuf *m)
2267{
2268	int et;
2269
2270	KASSERT((if_getcapenable(ifp) & IFCAP_NETMAP) != 0,
2271	    ("%s: ifp %p is not in netmap mode", __func__, ifp));
2272
2273	if (determine_ethertype_and_pullup(&m, &et) != 0) {
2274		if_inc_counter(ifp, IFCOUNTER_IERRORS, 1);
2275		m_freem(m);
2276		return;
2277	}
2278	CURVNET_SET(if_getvnet(ifp));
2279	switch (et) {
2280	case ETHERTYPE_IP:
2281		m_adj(m, sizeof(struct ether_header));
2282		netisr_dispatch(NETISR_IP, m);
2283		break;
2284	case ETHERTYPE_IPV6:
2285		m_adj(m, sizeof(struct ether_header));
2286		netisr_dispatch(NETISR_IPV6, m);
2287		break;
2288	default:
2289		__assert_unreachable();
2290	}
2291	CURVNET_RESTORE();
2292}
2293
2294/*
2295 * Deliver a packet to the host RX ring.  Because the interface is in netmap
2296 * mode, the if_transmit() call should pass the packet to netmap_transmit().
2297 */
2298static int
2299wg_xmit_netmap(if_t ifp, struct mbuf *m, int af)
2300{
2301	struct ether_header *eh;
2302
2303	if (__predict_false(if_tunnel_check_nesting(ifp, m, MTAG_WGLOOP,
2304	    MAX_LOOPS))) {
		if_printf(ifp, "%s: packet loop detected\n", __func__);
2306		if_inc_counter(ifp, IFCOUNTER_IERRORS, 1);
2307		m_freem(m);
2308		return (ELOOP);
2309	}
2310
2311	M_PREPEND(m, ETHER_HDR_LEN, M_NOWAIT);
2312	if (__predict_false(m == NULL)) {
2313		if_inc_counter(ifp, IFCOUNTER_IQDROPS, 1);
2314		return (ENOBUFS);
2315	}
2316
2317	eh = mtod(m, struct ether_header *);
2318	eh->ether_type = af == AF_INET ?
2319	    htons(ETHERTYPE_IP) : htons(ETHERTYPE_IPV6);
2320	memcpy(eh->ether_shost, "\x06\x06\x06\x06\x06\x06", ETHER_ADDR_LEN);
2321	memcpy(eh->ether_dhost, "\xff\xff\xff\xff\xff\xff", ETHER_ADDR_LEN);
2322	return (if_transmit(ifp, m));
2323}
2324#endif /* DEV_NETMAP */
2325
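/*
 * if_output handler: derive the address family (handling raw BPF writes
 * specially), then pass the packet to wg_xmit().  When the interface is in
 * netmap mode the packet is instead pushed to the host RX ring.
 */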
2326static int
2327wg_output(if_t ifp, struct mbuf *m, const struct sockaddr *dst, struct route *ro)
2328{
2329	sa_family_t parsed_af;
2330	uint32_t af, mtu;
2331	int ret;
2332	struct mbuf *defragged;
2333
2334	/* BPF writes need to be handled specially. */
2335	if (dst->sa_family == AF_UNSPEC || dst->sa_family == pseudo_AF_HDRCMPLT)
2336		memcpy(&af, dst->sa_data, sizeof(af));
2337	else
2338		af = dst->sa_family;
2339	if (af == AF_UNSPEC) {
2340		xmit_err(ifp, m, NULL, af);
2341		return (EAFNOSUPPORT);
2342	}
2343
2344#ifdef DEV_NETMAP
2345	if ((if_getcapenable(ifp) & IFCAP_NETMAP) != 0)
2346		return (wg_xmit_netmap(ifp, m, af));
2347#endif
2348
2349	defragged = m_defrag(m, M_NOWAIT);
2350	if (defragged)
2351		m = defragged;
2352	m = m_unshare(m, M_NOWAIT);
2353	if (!m) {
2354		xmit_err(ifp, m, NULL, AF_UNSPEC);
2355		return (ENOBUFS);
2356	}
2357
2358	ret = determine_af_and_pullup(&m, &parsed_af);
2359	if (ret) {
2360		xmit_err(ifp, m, NULL, AF_UNSPEC);
2361		return (ret);
2362	}
2363	if (parsed_af != af) {
2364		xmit_err(ifp, m, NULL, AF_UNSPEC);
2365		return (EAFNOSUPPORT);
2366	}
2367	mtu = (ro != NULL && ro->ro_mtu > 0) ? ro->ro_mtu : if_getmtu(ifp);
2368	return (wg_xmit(ifp, m, parsed_af, mtu));
2369}
2370
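/*
 * Apply one peer description from a SIOCSWG nvlist: create or look up the
 * peer by public key, then update its endpoint, preshared key, keepalive
 * interval and allowed-IPs, or remove the peer if "remove" is set.
 */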
2371static int
2372wg_peer_add(struct wg_softc *sc, const nvlist_t *nvl)
2373{
2374	uint8_t			 public[WG_KEY_SIZE];
2375	const void *pub_key, *preshared_key = NULL;
2376	const struct sockaddr *endpoint;
2377	int err;
2378	size_t size;
2379	struct noise_remote *remote;
2380	struct wg_peer *peer = NULL;
2381	bool need_insert = false;
2382
2383	sx_assert(&sc->sc_lock, SX_XLOCKED);
2384
2385	if (!nvlist_exists_binary(nvl, "public-key")) {
2386		return (EINVAL);
2387	}
2388	pub_key = nvlist_get_binary(nvl, "public-key", &size);
2389	if (size != WG_KEY_SIZE) {
2390		return (EINVAL);
2391	}
2392	if (noise_local_keys(sc->sc_local, public, NULL) == 0 &&
2393	    bcmp(public, pub_key, WG_KEY_SIZE) == 0) {
		return (0); /* Silently ignored; not actually a failure. */
2395	}
2396	if ((remote = noise_remote_lookup(sc->sc_local, pub_key)) != NULL)
2397		peer = noise_remote_arg(remote);
2398	if (nvlist_exists_bool(nvl, "remove") &&
2399		nvlist_get_bool(nvl, "remove")) {
2400		if (remote != NULL) {
2401			wg_peer_destroy(peer);
2402			noise_remote_put(remote);
2403		}
2404		return (0);
2405	}
2406	if (nvlist_exists_bool(nvl, "replace-allowedips") &&
2407		nvlist_get_bool(nvl, "replace-allowedips") &&
2408	    peer != NULL) {
2409
2410		wg_aip_remove_all(sc, peer);
2411	}
2412	if (peer == NULL) {
2413		peer = wg_peer_alloc(sc, pub_key);
2414		need_insert = true;
2415	}
2416	if (nvlist_exists_binary(nvl, "endpoint")) {
2417		endpoint = nvlist_get_binary(nvl, "endpoint", &size);
2418		if (size > sizeof(peer->p_endpoint.e_remote)) {
2419			err = EINVAL;
2420			goto out;
2421		}
2422		memcpy(&peer->p_endpoint.e_remote, endpoint, size);
2423	}
2424	if (nvlist_exists_binary(nvl, "preshared-key")) {
2425		preshared_key = nvlist_get_binary(nvl, "preshared-key", &size);
2426		if (size != WG_KEY_SIZE) {
2427			err = EINVAL;
2428			goto out;
2429		}
2430		noise_remote_set_psk(peer->p_remote, preshared_key);
2431	}
2432	if (nvlist_exists_number(nvl, "persistent-keepalive-interval")) {
2433		uint64_t pki = nvlist_get_number(nvl, "persistent-keepalive-interval");
2434		if (pki > UINT16_MAX) {
2435			err = EINVAL;
2436			goto out;
2437		}
2438		wg_timers_set_persistent_keepalive(peer, pki);
2439	}
2440	if (nvlist_exists_nvlist_array(nvl, "allowed-ips")) {
2441		const void *addr;
2442		uint64_t cidr;
2443		const nvlist_t * const * aipl;
2444		size_t allowedip_count;
2445
2446		aipl = nvlist_get_nvlist_array(nvl, "allowed-ips", &allowedip_count);
2447		for (size_t idx = 0; idx < allowedip_count; idx++) {
2448			if (!nvlist_exists_number(aipl[idx], "cidr"))
2449				continue;
2450			cidr = nvlist_get_number(aipl[idx], "cidr");
2451			if (nvlist_exists_binary(aipl[idx], "ipv4")) {
2452				addr = nvlist_get_binary(aipl[idx], "ipv4", &size);
2453				if (addr == NULL || cidr > 32 || size != sizeof(struct in_addr)) {
2454					err = EINVAL;
2455					goto out;
2456				}
2457				if ((err = wg_aip_add(sc, peer, AF_INET, addr, cidr)) != 0)
2458					goto out;
2459			} else if (nvlist_exists_binary(aipl[idx], "ipv6")) {
2460				addr = nvlist_get_binary(aipl[idx], "ipv6", &size);
2461				if (addr == NULL || cidr > 128 || size != sizeof(struct in6_addr)) {
2462					err = EINVAL;
2463					goto out;
2464				}
2465				if ((err = wg_aip_add(sc, peer, AF_INET6, addr, cidr)) != 0)
2466					goto out;
2467			} else {
2468				continue;
2469			}
2470		}
2471	}
2472	if (need_insert) {
2473		if ((err = noise_remote_enable(peer->p_remote)) != 0)
2474			goto out;
2475		TAILQ_INSERT_TAIL(&sc->sc_peers, peer, p_entry);
2476		sc->sc_peers_num++;
2477		if (if_getlinkstate(sc->sc_ifp) == LINK_STATE_UP)
2478			wg_timers_enable(peer);
2479	}
2480	if (remote != NULL)
2481		noise_remote_put(remote);
2482	return (0);
2483out:
2484	if (need_insert) /* If we fail, only destroy if it was new. */
2485		wg_peer_destroy(peer);
2486	if (remote != NULL)
2487		noise_remote_put(remote);
2488	return (err);
2489}
2490
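/*
 * SIOCSWG handler: copy in and unpack the configuration nvlist, then apply
 * interface-wide settings (listen port, private key, user cookie) and each
 * peer entry while holding the softc lock exclusively.
 */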
2491static int
2492wgc_set(struct wg_softc *sc, struct wg_data_io *wgd)
2493{
2494	uint8_t public[WG_KEY_SIZE], private[WG_KEY_SIZE];
2495	if_t ifp;
2496	void *nvlpacked;
2497	nvlist_t *nvl;
2498	ssize_t size;
2499	int err;
2500
2501	ifp = sc->sc_ifp;
2502	if (wgd->wgd_size == 0 || wgd->wgd_data == NULL)
2503		return (EFAULT);
2504
	/*
	 * Can nvlists be streamed in?  It's not nice to impose arbitrary
	 * limits like that, but there needs to be _some_ limitation.
	 */
2507	if (wgd->wgd_size >= UINT32_MAX / 2)
2508		return (E2BIG);
2509
2510	nvlpacked = malloc(wgd->wgd_size, M_TEMP, M_WAITOK | M_ZERO);
2511
2512	err = copyin(wgd->wgd_data, nvlpacked, wgd->wgd_size);
2513	if (err)
2514		goto out;
2515	nvl = nvlist_unpack(nvlpacked, wgd->wgd_size, 0);
2516	if (nvl == NULL) {
2517		err = EBADMSG;
2518		goto out;
2519	}
2520	sx_xlock(&sc->sc_lock);
2521	if (nvlist_exists_bool(nvl, "replace-peers") &&
2522		nvlist_get_bool(nvl, "replace-peers"))
2523		wg_peer_destroy_all(sc);
2524	if (nvlist_exists_number(nvl, "listen-port")) {
2525		uint64_t new_port = nvlist_get_number(nvl, "listen-port");
2526		if (new_port > UINT16_MAX) {
2527			err = EINVAL;
2528			goto out_locked;
2529		}
2530		if (new_port != sc->sc_socket.so_port) {
2531			if ((if_getdrvflags(ifp) & IFF_DRV_RUNNING) != 0) {
2532				if ((err = wg_socket_init(sc, new_port)) != 0)
2533					goto out_locked;
2534			} else
2535				sc->sc_socket.so_port = new_port;
2536		}
2537	}
2538	if (nvlist_exists_binary(nvl, "private-key")) {
2539		const void *key = nvlist_get_binary(nvl, "private-key", &size);
2540		if (size != WG_KEY_SIZE) {
2541			err = EINVAL;
2542			goto out_locked;
2543		}
2544
2545		if (noise_local_keys(sc->sc_local, NULL, private) != 0 ||
2546		    timingsafe_bcmp(private, key, WG_KEY_SIZE) != 0) {
2547			struct wg_peer *peer;
2548
2549			if (curve25519_generate_public(public, key)) {
2550				/* Peer conflict: remove conflicting peer. */
2551				struct noise_remote *remote;
2552				if ((remote = noise_remote_lookup(sc->sc_local,
2553				    public)) != NULL) {
2554					peer = noise_remote_arg(remote);
2555					wg_peer_destroy(peer);
2556					noise_remote_put(remote);
2557				}
2558			}
2559
2560			/*
2561			 * Set the private key and invalidate all existing
2562			 * handshakes.
2563			 */
2564			/* Note: we might be removing the private key. */
2565			noise_local_private(sc->sc_local, key);
2566			if (noise_local_keys(sc->sc_local, NULL, NULL) == 0)
2567				cookie_checker_update(&sc->sc_cookie, public);
2568			else
2569				cookie_checker_update(&sc->sc_cookie, NULL);
2570		}
2571	}
2572	if (nvlist_exists_number(nvl, "user-cookie")) {
2573		uint64_t user_cookie = nvlist_get_number(nvl, "user-cookie");
2574		if (user_cookie > UINT32_MAX) {
2575			err = EINVAL;
2576			goto out_locked;
2577		}
2578		err = wg_socket_set_cookie(sc, user_cookie);
2579		if (err)
2580			goto out_locked;
2581	}
2582	if (nvlist_exists_nvlist_array(nvl, "peers")) {
2583		size_t peercount;
2584		const nvlist_t * const*nvl_peers;
2585
2586		nvl_peers = nvlist_get_nvlist_array(nvl, "peers", &peercount);
		for (size_t i = 0; i < peercount; i++) {
2588			err = wg_peer_add(sc, nvl_peers[i]);
2589			if (err != 0)
2590				goto out_locked;
2591		}
2592	}
2593
2594out_locked:
2595	sx_xunlock(&sc->sc_lock);
2596	nvlist_destroy(nvl);
2597out:
2598	zfree(nvlpacked, M_TEMP);
2599	return (err);
2600}
2601
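/*
 * SIOCGWG handler: build an nvlist describing the interface and its peers
 * and copy the packed result out to userspace.  Keys are only included for
 * privileged callers; a zero-sized request just reports the required size.
 */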
2602static int
2603wgc_get(struct wg_softc *sc, struct wg_data_io *wgd)
2604{
2605	uint8_t public_key[WG_KEY_SIZE] = { 0 };
2606	uint8_t private_key[WG_KEY_SIZE] = { 0 };
2607	uint8_t preshared_key[NOISE_SYMMETRIC_KEY_LEN] = { 0 };
2608	nvlist_t *nvl, *nvl_peer, *nvl_aip, **nvl_peers, **nvl_aips;
2609	size_t size, peer_count, aip_count, i, j;
2610	struct wg_timespec64 ts64;
2611	struct wg_peer *peer;
2612	struct wg_aip *aip;
2613	void *packed;
2614	int err = 0;
2615
2616	nvl = nvlist_create(0);
2617	if (!nvl)
2618		return (ENOMEM);
2619
2620	sx_slock(&sc->sc_lock);
2621
2622	if (sc->sc_socket.so_port != 0)
2623		nvlist_add_number(nvl, "listen-port", sc->sc_socket.so_port);
2624	if (sc->sc_socket.so_user_cookie != 0)
2625		nvlist_add_number(nvl, "user-cookie", sc->sc_socket.so_user_cookie);
2626	if (noise_local_keys(sc->sc_local, public_key, private_key) == 0) {
2627		nvlist_add_binary(nvl, "public-key", public_key, WG_KEY_SIZE);
2628		if (wgc_privileged(sc))
2629			nvlist_add_binary(nvl, "private-key", private_key, WG_KEY_SIZE);
2630		explicit_bzero(private_key, sizeof(private_key));
2631	}
2632	peer_count = sc->sc_peers_num;
2633	if (peer_count) {
2634		nvl_peers = mallocarray(peer_count, sizeof(void *), M_NVLIST, M_WAITOK | M_ZERO);
2635		i = 0;
2636		TAILQ_FOREACH(peer, &sc->sc_peers, p_entry) {
2637			if (i >= peer_count)
2638				panic("peers changed from under us");
2639
2640			nvl_peers[i++] = nvl_peer = nvlist_create(0);
2641			if (!nvl_peer) {
2642				err = ENOMEM;
2643				goto err_peer;
2644			}
2645
2646			(void)noise_remote_keys(peer->p_remote, public_key, preshared_key);
2647			nvlist_add_binary(nvl_peer, "public-key", public_key, sizeof(public_key));
2648			if (wgc_privileged(sc))
2649				nvlist_add_binary(nvl_peer, "preshared-key", preshared_key, sizeof(preshared_key));
2650			explicit_bzero(preshared_key, sizeof(preshared_key));
2651			if (peer->p_endpoint.e_remote.r_sa.sa_family == AF_INET)
2652				nvlist_add_binary(nvl_peer, "endpoint", &peer->p_endpoint.e_remote, sizeof(struct sockaddr_in));
2653			else if (peer->p_endpoint.e_remote.r_sa.sa_family == AF_INET6)
2654				nvlist_add_binary(nvl_peer, "endpoint", &peer->p_endpoint.e_remote, sizeof(struct sockaddr_in6));
2655			wg_timers_get_last_handshake(peer, &ts64);
2656			nvlist_add_binary(nvl_peer, "last-handshake-time", &ts64, sizeof(ts64));
2657			nvlist_add_number(nvl_peer, "persistent-keepalive-interval", peer->p_persistent_keepalive_interval);
2658			nvlist_add_number(nvl_peer, "rx-bytes", counter_u64_fetch(peer->p_rx_bytes));
2659			nvlist_add_number(nvl_peer, "tx-bytes", counter_u64_fetch(peer->p_tx_bytes));
2660
2661			aip_count = peer->p_aips_num;
2662			if (aip_count) {
2663				nvl_aips = mallocarray(aip_count, sizeof(void *), M_NVLIST, M_WAITOK | M_ZERO);
2664				j = 0;
2665				LIST_FOREACH(aip, &peer->p_aips, a_entry) {
2666					if (j >= aip_count)
2667						panic("aips changed from under us");
2668
2669					nvl_aips[j++] = nvl_aip = nvlist_create(0);
2670					if (!nvl_aip) {
2671						err = ENOMEM;
2672						goto err_aip;
2673					}
2674					if (aip->a_af == AF_INET) {
2675						nvlist_add_binary(nvl_aip, "ipv4", &aip->a_addr.in, sizeof(aip->a_addr.in));
2676						nvlist_add_number(nvl_aip, "cidr", bitcount32(aip->a_mask.ip));
2677					}
2678#ifdef INET6
2679					else if (aip->a_af == AF_INET6) {
2680						nvlist_add_binary(nvl_aip, "ipv6", &aip->a_addr.in6, sizeof(aip->a_addr.in6));
2681						nvlist_add_number(nvl_aip, "cidr", in6_mask2len(&aip->a_mask.in6, NULL));
2682					}
2683#endif
2684				}
2685				nvlist_add_nvlist_array(nvl_peer, "allowed-ips", (const nvlist_t *const *)nvl_aips, aip_count);
2686			err_aip:
2687				for (j = 0; j < aip_count; ++j)
2688					nvlist_destroy(nvl_aips[j]);
2689				free(nvl_aips, M_NVLIST);
2690				if (err)
2691					goto err_peer;
2692			}
2693		}
2694		nvlist_add_nvlist_array(nvl, "peers", (const nvlist_t * const *)nvl_peers, peer_count);
2695	err_peer:
2696		for (i = 0; i < peer_count; ++i)
2697			nvlist_destroy(nvl_peers[i]);
2698		free(nvl_peers, M_NVLIST);
2699		if (err) {
2700			sx_sunlock(&sc->sc_lock);
2701			goto err;
2702		}
2703	}
2704	sx_sunlock(&sc->sc_lock);
2705	packed = nvlist_pack(nvl, &size);
2706	if (!packed) {
2707		err = ENOMEM;
2708		goto err;
2709	}
2710	if (!wgd->wgd_size) {
2711		wgd->wgd_size = size;
2712		goto out;
2713	}
2714	if (wgd->wgd_size < size) {
2715		err = ENOSPC;
2716		goto out;
2717	}
2718	err = copyout(packed, wgd->wgd_data, size);
2719	wgd->wgd_size = size;
2720
2721out:
2722	zfree(packed, M_NVLIST);
2723err:
2724	nvlist_destroy(nvl);
2725	return (err);
2726}
2727
2728static int
2729wg_ioctl(if_t ifp, u_long cmd, caddr_t data)
2730{
2731	struct wg_data_io *wgd = (struct wg_data_io *)data;
2732	struct ifreq *ifr = (struct ifreq *)data;
2733	struct wg_softc *sc;
2734	int ret = 0;
2735
2736	sx_slock(&wg_sx);
2737	sc = if_getsoftc(ifp);
2738	if (!sc) {
2739		ret = ENXIO;
2740		goto out;
2741	}
2742
2743	switch (cmd) {
2744	case SIOCSWG:
2745		ret = priv_check(curthread, PRIV_NET_WG);
2746		if (ret == 0)
2747			ret = wgc_set(sc, wgd);
2748		break;
2749	case SIOCGWG:
2750		ret = wgc_get(sc, wgd);
2751		break;
2752	/* Interface IOCTLs */
2753	case SIOCSIFADDR:
2754		/*
		 * This differs from *BSD norms, but is more consistent with
		 * how WireGuard behaves elsewhere.
2757		 */
2758		break;
2759	case SIOCSIFFLAGS:
2760		if (if_getflags(ifp) & IFF_UP)
2761			ret = wg_up(sc);
2762		else
2763			wg_down(sc);
2764		break;
2765	case SIOCSIFMTU:
2766		if (ifr->ifr_mtu <= 0 || ifr->ifr_mtu > MAX_MTU)
2767			ret = EINVAL;
2768		else
2769			if_setmtu(ifp, ifr->ifr_mtu);
2770		break;
2771	case SIOCADDMULTI:
2772	case SIOCDELMULTI:
2773		break;
2774	case SIOCGTUNFIB:
2775		ifr->ifr_fib = sc->sc_socket.so_fibnum;
2776		break;
2777	case SIOCSTUNFIB:
2778		ret = priv_check(curthread, PRIV_NET_WG);
2779		if (ret)
2780			break;
2781		ret = priv_check(curthread, PRIV_NET_SETIFFIB);
2782		if (ret)
2783			break;
2784		sx_xlock(&sc->sc_lock);
2785		ret = wg_socket_set_fibnum(sc, ifr->ifr_fib);
2786		sx_xunlock(&sc->sc_lock);
2787		break;
2788	default:
2789		ret = ENOTTY;
2790	}
2791
2792out:
2793	sx_sunlock(&wg_sx);
2794	return (ret);
2795}
2796
2797static int
2798wg_up(struct wg_softc *sc)
2799{
2800	if_t ifp = sc->sc_ifp;
2801	struct wg_peer *peer;
2802	int rc = EBUSY;
2803
2804	sx_xlock(&sc->sc_lock);
2805	/* Jail's being removed, no more wg_up(). */
2806	if ((sc->sc_flags & WGF_DYING) != 0)
2807		goto out;
2808
2809	/* Silent success if we're already running. */
2810	rc = 0;
2811	if (if_getdrvflags(ifp) & IFF_DRV_RUNNING)
2812		goto out;
2813	if_setdrvflagbits(ifp, IFF_DRV_RUNNING, 0);
2814
2815	rc = wg_socket_init(sc, sc->sc_socket.so_port);
2816	if (rc == 0) {
2817		TAILQ_FOREACH(peer, &sc->sc_peers, p_entry)
2818			wg_timers_enable(peer);
2819		if_link_state_change(sc->sc_ifp, LINK_STATE_UP);
2820	} else {
2821		if_setdrvflagbits(ifp, 0, IFF_DRV_RUNNING);
2822		DPRINTF(sc, "Unable to initialize sockets: %d\n", rc);
2823	}
2824out:
2825	sx_xunlock(&sc->sc_lock);
2826	return (rc);
2827}
2828
2829static void
2830wg_down(struct wg_softc *sc)
2831{
2832	if_t ifp = sc->sc_ifp;
2833	struct wg_peer *peer;
2834
2835	sx_xlock(&sc->sc_lock);
2836	if (!(if_getdrvflags(ifp) & IFF_DRV_RUNNING)) {
2837		sx_xunlock(&sc->sc_lock);
2838		return;
2839	}
2840	if_setdrvflagbits(ifp, 0, IFF_DRV_RUNNING);
2841
2842	TAILQ_FOREACH(peer, &sc->sc_peers, p_entry) {
2843		wg_queue_purge(&peer->p_stage_queue);
2844		wg_timers_disable(peer);
2845	}
2846
2847	wg_queue_purge(&sc->sc_handshake_queue);
2848
2849	TAILQ_FOREACH(peer, &sc->sc_peers, p_entry) {
2850		noise_remote_handshake_clear(peer->p_remote);
2851		noise_remote_keypairs_clear(peer->p_remote);
2852	}
2853
2854	if_link_state_change(sc->sc_ifp, LINK_STATE_DOWN);
2855	wg_socket_uninit(sc);
2856
2857	sx_xunlock(&sc->sc_lock);
2858}
2859
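/*
 * Cloner create routine: allocate the softc, the allowed-IPs radix heads
 * and the per-CPU encrypt/decrypt grouptasks, then set up and attach the
 * ifnet.
 */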
2860static int
2861wg_clone_create(struct if_clone *ifc, char *name, size_t len,
2862    struct ifc_data *ifd, struct ifnet **ifpp)
2863{
2864	struct wg_softc *sc;
2865	if_t ifp;
2866
2867	sc = malloc(sizeof(*sc), M_WG, M_WAITOK | M_ZERO);
2868
2869	sc->sc_local = noise_local_alloc(sc);
2870
2871	sc->sc_encrypt = mallocarray(sizeof(struct grouptask), mp_ncpus, M_WG, M_WAITOK | M_ZERO);
2872
2873	sc->sc_decrypt = mallocarray(sizeof(struct grouptask), mp_ncpus, M_WG, M_WAITOK | M_ZERO);
2874
2875	if (!rn_inithead((void **)&sc->sc_aip4, offsetof(struct aip_addr, in) * NBBY))
2876		goto free_decrypt;
2877
2878	if (!rn_inithead((void **)&sc->sc_aip6, offsetof(struct aip_addr, in6) * NBBY))
2879		goto free_aip4;
2880
2881	atomic_add_int(&clone_count, 1);
2882	ifp = sc->sc_ifp = if_alloc(IFT_WIREGUARD);
2883
2884	sc->sc_ucred = crhold(curthread->td_ucred);
2885	sc->sc_socket.so_fibnum = curthread->td_proc->p_fibnum;
2886	sc->sc_socket.so_port = 0;
2887
2888	TAILQ_INIT(&sc->sc_peers);
2889	sc->sc_peers_num = 0;
2890
2891	cookie_checker_init(&sc->sc_cookie);
2892
2893	RADIX_NODE_HEAD_LOCK_INIT(sc->sc_aip4);
2894	RADIX_NODE_HEAD_LOCK_INIT(sc->sc_aip6);
2895
2896	GROUPTASK_INIT(&sc->sc_handshake, 0, (gtask_fn_t *)wg_softc_handshake_receive, sc);
2897	taskqgroup_attach(qgroup_wg_tqg, &sc->sc_handshake, sc, NULL, NULL, "wg tx initiation");
2898	wg_queue_init(&sc->sc_handshake_queue, "hsq");
2899
2900	for (int i = 0; i < mp_ncpus; i++) {
2901		GROUPTASK_INIT(&sc->sc_encrypt[i], 0,
2902		     (gtask_fn_t *)wg_softc_encrypt, sc);
2903		taskqgroup_attach_cpu(qgroup_wg_tqg, &sc->sc_encrypt[i], sc, i, NULL, NULL, "wg encrypt");
2904		GROUPTASK_INIT(&sc->sc_decrypt[i], 0,
2905		    (gtask_fn_t *)wg_softc_decrypt, sc);
2906		taskqgroup_attach_cpu(qgroup_wg_tqg, &sc->sc_decrypt[i], sc, i, NULL, NULL, "wg decrypt");
2907	}
2908
2909	wg_queue_init(&sc->sc_encrypt_parallel, "encp");
2910	wg_queue_init(&sc->sc_decrypt_parallel, "decp");
2911
2912	sx_init(&sc->sc_lock, "wg softc lock");
2913
2914	if_setsoftc(ifp, sc);
2915	if_setcapabilities(ifp, WG_CAPS);
2916	if_setcapenable(ifp, WG_CAPS);
2917	if_initname(ifp, wgname, ifd->unit);
2918
2919	if_setmtu(ifp, DEFAULT_MTU);
2920	if_setflags(ifp, IFF_NOARP | IFF_MULTICAST);
2921	if_setinitfn(ifp, wg_init);
2922	if_setreassignfn(ifp, wg_reassign);
2923	if_setqflushfn(ifp, wg_qflush);
2924#ifdef DEV_NETMAP
2925	if_settransmitfn(ifp, wg_transmit);
2926	if_setinputfn(ifp, wg_if_input);
2927#endif
2928	if_setoutputfn(ifp, wg_output);
2929	if_setioctlfn(ifp, wg_ioctl);
2930	if_attach(ifp);
2931	bpfattach(ifp, DLT_NULL, sizeof(uint32_t));
2932#ifdef INET6
2933	ND_IFINFO(ifp)->flags &= ~ND6_IFF_AUTO_LINKLOCAL;
2934	ND_IFINFO(ifp)->flags |= ND6_IFF_NO_DAD;
2935#endif
2936	sx_xlock(&wg_sx);
2937	LIST_INSERT_HEAD(&wg_list, sc, sc_entry);
2938	sx_xunlock(&wg_sx);
2939	*ifpp = ifp;
2940	return (0);
2941free_aip4:
2942	RADIX_NODE_HEAD_DESTROY(sc->sc_aip4);
2943	free(sc->sc_aip4, M_RTABLE);
2944free_decrypt:
2945	free(sc->sc_decrypt, M_WG);
2946	free(sc->sc_encrypt, M_WG);
2947	noise_local_free(sc->sc_local, NULL);
2948	free(sc, M_WG);
2949	return (ENOMEM);
2950}
2951
2952static void
2953wg_clone_deferred_free(struct noise_local *l)
2954{
2955	struct wg_softc *sc = noise_local_arg(l);
2956
2957	free(sc, M_WG);
2958	atomic_add_int(&clone_count, -1);
2959}
2960
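/*
 * Cloner destroy routine: mark the softc dying, tear down the socket, wait
 * for in-flight work to drain, then destroy peers, queues, radix heads and
 * finally the ifnet.  The softc itself is freed once the last noise_local
 * reference is dropped, via wg_clone_deferred_free().
 */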
2961static int
2962wg_clone_destroy(struct if_clone *ifc, if_t ifp, uint32_t flags)
2963{
2964	struct wg_softc *sc = if_getsoftc(ifp);
2965	struct ucred *cred;
2966
2967	sx_xlock(&wg_sx);
2968	if_setsoftc(ifp, NULL);
2969	sx_xlock(&sc->sc_lock);
2970	sc->sc_flags |= WGF_DYING;
2971	cred = sc->sc_ucred;
2972	sc->sc_ucred = NULL;
2973	sx_xunlock(&sc->sc_lock);
2974	LIST_REMOVE(sc, sc_entry);
2975	sx_xunlock(&wg_sx);
2976
2977	if_link_state_change(sc->sc_ifp, LINK_STATE_DOWN);
2978	CURVNET_SET(if_getvnet(sc->sc_ifp));
2979	if_purgeaddrs(sc->sc_ifp);
2980	CURVNET_RESTORE();
2981
2982	sx_xlock(&sc->sc_lock);
2983	wg_socket_uninit(sc);
2984	sx_xunlock(&sc->sc_lock);
2985
2986	/*
	 * There is no guarantee that all traffic has passed until the epoch has
2988	 * elapsed with the socket closed.
2989	 */
2990	NET_EPOCH_WAIT();
2991
2992	taskqgroup_drain_all(qgroup_wg_tqg);
2993	sx_xlock(&sc->sc_lock);
2994	wg_peer_destroy_all(sc);
2995	NET_EPOCH_DRAIN_CALLBACKS();
2996	sx_xunlock(&sc->sc_lock);
2997	sx_destroy(&sc->sc_lock);
2998	taskqgroup_detach(qgroup_wg_tqg, &sc->sc_handshake);
2999	for (int i = 0; i < mp_ncpus; i++) {
3000		taskqgroup_detach(qgroup_wg_tqg, &sc->sc_encrypt[i]);
3001		taskqgroup_detach(qgroup_wg_tqg, &sc->sc_decrypt[i]);
3002	}
3003	free(sc->sc_encrypt, M_WG);
3004	free(sc->sc_decrypt, M_WG);
3005	wg_queue_deinit(&sc->sc_handshake_queue);
3006	wg_queue_deinit(&sc->sc_encrypt_parallel);
3007	wg_queue_deinit(&sc->sc_decrypt_parallel);
3008
3009	RADIX_NODE_HEAD_DESTROY(sc->sc_aip4);
3010	RADIX_NODE_HEAD_DESTROY(sc->sc_aip6);
3011	rn_detachhead((void **)&sc->sc_aip4);
3012	rn_detachhead((void **)&sc->sc_aip6);
3013
3014	cookie_checker_free(&sc->sc_cookie);
3015
3016	if (cred != NULL)
3017		crfree(cred);
3018	bpfdetach(sc->sc_ifp);
3019	if_detach(sc->sc_ifp);
3020	if_free(sc->sc_ifp);
3021
3022	noise_local_free(sc->sc_local, wg_clone_deferred_free);
3023
3024	return (0);
3025}
3026
3027static void
3028wg_qflush(if_t ifp __unused)
3029{
3030}
3031
3032/*
 * Privileged information (private-key, preshared-key) is only exported to
3034 * root and jailed root by default.
3035 */
3036static bool
3037wgc_privileged(struct wg_softc *sc)
3038{
3039	struct thread *td;
3040
3041	td = curthread;
3042	return (priv_check(td, PRIV_NET_WG) == 0);
3043}
3044
3045static void
3046wg_reassign(if_t ifp, struct vnet *new_vnet __unused,
3047    char *unused __unused)
3048{
3049	struct wg_softc *sc;
3050
3051	sc = if_getsoftc(ifp);
3052	wg_down(sc);
3053}
3054
3055static void
3056wg_init(void *xsc)
3057{
3058	struct wg_softc *sc;
3059
3060	sc = xsc;
3061	wg_up(sc);
3062}
3063
3064static void
3065vnet_wg_init(const void *unused __unused)
3066{
3067	struct if_clone_addreq req = {
3068		.create_f = wg_clone_create,
3069		.destroy_f = wg_clone_destroy,
3070		.flags = IFC_F_AUTOUNIT,
3071	};
3072	V_wg_cloner = ifc_attach_cloner(wgname, &req);
3073}
3074VNET_SYSINIT(vnet_wg_init, SI_SUB_PROTO_IFATTACHDOMAIN, SI_ORDER_ANY,
3075	     vnet_wg_init, NULL);
3076
3077static void
3078vnet_wg_uninit(const void *unused __unused)
3079{
3080	if (V_wg_cloner)
3081		ifc_detach_cloner(V_wg_cloner);
3082}
3083VNET_SYSUNINIT(vnet_wg_uninit, SI_SUB_PROTO_IFATTACHDOMAIN, SI_ORDER_ANY,
3084	       vnet_wg_uninit, NULL);
3085
3086static int
3087wg_prison_remove(void *obj, void *data __unused)
3088{
3089	const struct prison *pr = obj;
3090	struct wg_softc *sc;
3091
3092	/*
3093	 * Do a pass through all if_wg interfaces and release creds on any from
3094	 * the jail that are supposed to be going away.  This will, in turn, let
	 * the jail die so that we don't end up with Schrödinger's jail.
3096	 */
3097	sx_slock(&wg_sx);
3098	LIST_FOREACH(sc, &wg_list, sc_entry) {
3099		sx_xlock(&sc->sc_lock);
3100		if (!(sc->sc_flags & WGF_DYING) && sc->sc_ucred && sc->sc_ucred->cr_prison == pr) {
3101			struct ucred *cred = sc->sc_ucred;
			DPRINTF(sc, "Creating jail is exiting\n");
3103			if_link_state_change(sc->sc_ifp, LINK_STATE_DOWN);
3104			wg_socket_uninit(sc);
3105			sc->sc_ucred = NULL;
3106			crfree(cred);
3107			sc->sc_flags |= WGF_DYING;
3108		}
3109		sx_xunlock(&sc->sc_lock);
3110	}
3111	sx_sunlock(&wg_sx);
3112
3113	return (0);
3114}
3115
3116#ifdef SELFTESTS
3117#include "selftest/allowedips.c"
3118static bool wg_run_selftests(void)
3119{
3120	bool ret = true;
3121	ret &= wg_allowedips_selftest();
3122	ret &= noise_counter_selftest();
3123	ret &= cookie_selftest();
3124	return ret;
3125}
3126#else
3127static inline bool wg_run_selftests(void) { return true; }
3128#endif
3129
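/*
 * Module load: create the wg_packet UMA zone, initialize the crypto and
 * cookie subsystems, register the jail-removal OSD hook and run the
 * self-tests when they are compiled in.
 */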
3130static int
3131wg_module_init(void)
3132{
3133	int ret;
3134	osd_method_t methods[PR_MAXMETHOD] = {
3135		[PR_METHOD_REMOVE] = wg_prison_remove,
3136	};
3137
3138	wg_packet_zone = uma_zcreate("wg packet", sizeof(struct wg_packet),
3139	     NULL, NULL, NULL, NULL, 0, 0);
3140
3141	ret = crypto_init();
3142	if (ret != 0)
3143		return (ret);
3144	ret = cookie_init();
3145	if (ret != 0)
3146		return (ret);
3147
3148	wg_osd_jail_slot = osd_jail_register(NULL, methods);
3149
3150	if (!wg_run_selftests())
3151		return (ENOTRECOVERABLE);
3152
3153	return (0);
3154}
3155
3156static void
3157wg_module_deinit(void)
3158{
3159	VNET_ITERATOR_DECL(vnet_iter);
3160	VNET_LIST_RLOCK();
3161	VNET_FOREACH(vnet_iter) {
3162		struct if_clone *clone = VNET_VNET(vnet_iter, wg_cloner);
3163		if (clone) {
3164			ifc_detach_cloner(clone);
3165			VNET_VNET(vnet_iter, wg_cloner) = NULL;
3166		}
3167	}
3168	VNET_LIST_RUNLOCK();
3169	NET_EPOCH_WAIT();
3170	MPASS(LIST_EMPTY(&wg_list));
3171	if (wg_osd_jail_slot != 0)
3172		osd_jail_deregister(wg_osd_jail_slot);
3173	cookie_deinit();
3174	crypto_deinit();
3175	if (wg_packet_zone != NULL)
3176		uma_zdestroy(wg_packet_zone);
3177}
3178
3179static int
3180wg_module_event_handler(module_t mod, int what, void *arg)
3181{
3182	switch (what) {
3183		case MOD_LOAD:
3184			return wg_module_init();
3185		case MOD_UNLOAD:
3186			wg_module_deinit();
3187			break;
3188		default:
3189			return (EOPNOTSUPP);
3190	}
3191	return (0);
3192}
3193
3194static moduledata_t wg_moduledata = {
3195	"if_wg",
3196	wg_module_event_handler,
3197	NULL
3198};
3199
3200DECLARE_MODULE(if_wg, wg_moduledata, SI_SUB_PSEUDO, SI_ORDER_ANY);
3201MODULE_VERSION(if_wg, WIREGUARD_VERSION);
3202MODULE_DEPEND(if_wg, crypto, 1, 1, 1);
3203