// SPDX-License-Identifier: GPL-2.0
/*
 * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
 */

#include "queueing.h"
#include "device.h"
#include "peer.h"
#include "timers.h"
#include "messages.h"
#include "cookie.h"
#include "socket.h"

#include <linux/ip.h>
#include <linux/ipv6.h>
#include <linux/udp.h>
#include <net/ip_tunnels.h>

/* Must be called with bh disabled. */
static void update_rx_stats(struct wg_peer *peer, size_t len)
{
	dev_sw_netstats_rx_add(peer->device->dev, len);
	peer->rx_bytes += len;
}

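/* Read the message type field directly from the packet as a little-endian
 * value, so it can be compared against cpu_to_le32() constants without a
 * byte swap on big-endian machines.
 */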
#define SKB_TYPE_LE32(skb) (((struct message_header *)(skb)->data)->type)

static size_t validate_header_len(struct sk_buff *skb)
{
	if (unlikely(skb->len < sizeof(struct message_header)))
		return 0;
	if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_DATA) &&
	    skb->len >= MESSAGE_MINIMUM_LENGTH)
		return sizeof(struct message_data);
	if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION) &&
	    skb->len == sizeof(struct message_handshake_initiation))
		return sizeof(struct message_handshake_initiation);
	if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE) &&
	    skb->len == sizeof(struct message_handshake_response))
		return sizeof(struct message_handshake_response);
	if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_HANDSHAKE_COOKIE) &&
	    skb->len == sizeof(struct message_handshake_cookie))
		return sizeof(struct message_handshake_cookie);
	return 0;
}

static int prepare_skb_header(struct sk_buff *skb, struct wg_device *wg)
{
	size_t data_offset, data_len, header_len;
	struct udphdr *udp;

	if (unlikely(!wg_check_packet_protocol(skb) ||
		     skb_transport_header(skb) < skb->head ||
		     (skb_transport_header(skb) + sizeof(struct udphdr)) >
			     skb_tail_pointer(skb)))
		return -EINVAL; /* Bogus IP header */
	udp = udp_hdr(skb);
	data_offset = (u8 *)udp - skb->data;
	if (unlikely(data_offset > U16_MAX ||
		     data_offset + sizeof(struct udphdr) > skb->len))
		/* Packet has offset at impossible location or isn't big enough
		 * to have UDP fields.
		 */
		return -EINVAL;
	data_len = ntohs(udp->len);
	if (unlikely(data_len < sizeof(struct udphdr) ||
		     data_len > skb->len - data_offset))
		/* UDP packet is reporting too small of a size or lying about
		 * its size.
		 */
		return -EINVAL;
	data_len -= sizeof(struct udphdr);
	data_offset = (u8 *)udp + sizeof(struct udphdr) - skb->data;
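	/* Linearize enough for a message header, trim off anything past the
	 * UDP-declared length, and then advance skb->data to the start of the
	 * WireGuard message.
	 */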
	if (unlikely(!pskb_may_pull(skb,
				data_offset + sizeof(struct message_header)) ||
		     pskb_trim(skb, data_len + data_offset) < 0))
		return -EINVAL;
	skb_pull(skb, data_offset);
	if (unlikely(skb->len != data_len))
		/* Final len does not agree with calculated len */
		return -EINVAL;
	header_len = validate_header_len(skb);
	if (unlikely(!header_len))
		return -EINVAL;
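	/* Pull the full header into the linear area, counting from the start
	 * of the packet, and then restore skb->data to the message itself.
	 */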
	__skb_push(skb, data_offset);
	if (unlikely(!pskb_may_pull(skb, data_offset + header_len)))
		return -EINVAL;
	__skb_pull(skb, data_offset);
	return 0;
}

static void wg_receive_handshake_packet(struct wg_device *wg,
					struct sk_buff *skb)
{
	enum cookie_mac_state mac_state;
	struct wg_peer *peer = NULL;
	/* This is global, so that our load calculation applies to the whole
	 * system. We don't care about races with it at all.
	 */
	static u64 last_under_load;
	bool packet_needs_cookie;
	bool under_load;

	if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_HANDSHAKE_COOKIE)) {
		net_dbg_skb_ratelimited("%s: Receiving cookie response from %pISpfsc\n",
					wg->dev->name, skb);
		wg_cookie_message_consume(
			(struct message_handshake_cookie *)skb->data, wg);
		return;
	}

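	/* Consider the system under load once the handshake queue is an
	 * eighth full, and keep treating it as loaded for a full second after
	 * the queue drains, so the cookie requirement doesn't flap.
	 */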
	under_load = atomic_read(&wg->handshake_queue_len) >=
			MAX_QUEUED_INCOMING_HANDSHAKES / 8;
	if (under_load) {
		last_under_load = ktime_get_coarse_boottime_ns();
	} else if (last_under_load) {
		under_load = !wg_birthdate_has_expired(last_under_load, 1);
		if (!under_load)
			last_under_load = 0;
	}
	mac_state = wg_cookie_validate_packet(&wg->cookie_checker, skb,
					      under_load);
	if ((under_load && mac_state == VALID_MAC_WITH_COOKIE) ||
	    (!under_load && mac_state == VALID_MAC_BUT_NO_COOKIE)) {
		packet_needs_cookie = false;
	} else if (under_load && mac_state == VALID_MAC_BUT_NO_COOKIE) {
		packet_needs_cookie = true;
	} else {
		net_dbg_skb_ratelimited("%s: Invalid MAC of handshake, dropping packet from %pISpfsc\n",
					wg->dev->name, skb);
		return;
	}

	switch (SKB_TYPE_LE32(skb)) {
	case cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION): {
		struct message_handshake_initiation *message =
			(struct message_handshake_initiation *)skb->data;

		if (packet_needs_cookie) {
			wg_packet_send_handshake_cookie(wg, skb,
							message->sender_index);
			return;
		}
		peer = wg_noise_handshake_consume_initiation(message, wg);
		if (unlikely(!peer)) {
			net_dbg_skb_ratelimited("%s: Invalid handshake initiation from %pISpfsc\n",
						wg->dev->name, skb);
			return;
		}
		wg_socket_set_peer_endpoint_from_skb(peer, skb);
		net_dbg_ratelimited("%s: Receiving handshake initiation from peer %llu (%pISpfsc)\n",
				    wg->dev->name, peer->internal_id,
				    &peer->endpoint.addr);
		wg_packet_send_handshake_response(peer);
		break;
	}
	case cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE): {
		struct message_handshake_response *message =
			(struct message_handshake_response *)skb->data;

		if (packet_needs_cookie) {
			wg_packet_send_handshake_cookie(wg, skb,
							message->sender_index);
			return;
		}
		peer = wg_noise_handshake_consume_response(message, wg);
		if (unlikely(!peer)) {
			net_dbg_skb_ratelimited("%s: Invalid handshake response from %pISpfsc\n",
						wg->dev->name, skb);
			return;
		}
		wg_socket_set_peer_endpoint_from_skb(peer, skb);
		net_dbg_ratelimited("%s: Receiving handshake response from peer %llu (%pISpfsc)\n",
				    wg->dev->name, peer->internal_id,
				    &peer->endpoint.addr);
		if (wg_noise_handshake_begin_session(&peer->handshake,
						     &peer->keypairs)) {
			wg_timers_session_derived(peer);
			wg_timers_handshake_complete(peer);
			/* Calling this function will either send any existing
			 * packets in the queue without a keepalive, which is
			 * the best case, or, if the queue is empty, it will
			 * send a keepalive in order to give immediate
			 * confirmation of the session.
			 */
			wg_packet_send_keepalive(peer);
		}
		break;
	}
	}

	if (unlikely(!peer)) {
		WARN(1, "Somehow a wrong type of packet wound up in the handshake queue!\n");
		return;
	}

	local_bh_disable();
	update_rx_stats(peer, skb->len);
	local_bh_enable();

	wg_timers_any_authenticated_packet_received(peer);
	wg_timers_any_authenticated_packet_traversal(peer);
	wg_peer_put(peer);
}

void wg_packet_handshake_receive_worker(struct work_struct *work)
{
	struct crypt_queue *queue = container_of(work, struct multicore_worker, work)->ptr;
	struct wg_device *wg = container_of(queue, struct wg_device, handshake_queue);
	struct sk_buff *skb;

	while ((skb = ptr_ring_consume_bh(&queue->ring)) != NULL) {
		wg_receive_handshake_packet(wg, skb);
		dev_kfree_skb(skb);
		atomic_dec(&wg->handshake_queue_len);
		cond_resched();
	}
}

static void keep_key_fresh(struct wg_peer *peer)
{
	struct noise_keypair *keypair;
	bool send;

	if (peer->sent_lastminute_handshake)
		return;

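	/* If we initiated the current keypair and it is old enough that it
	 * will expire before a full keepalive/rekey cycle can complete,
	 * preemptively kick off a new handshake.
	 */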
	rcu_read_lock_bh();
	keypair = rcu_dereference_bh(peer->keypairs.current_keypair);
	send = keypair && READ_ONCE(keypair->sending.is_valid) &&
	       keypair->i_am_the_initiator &&
	       wg_birthdate_has_expired(keypair->sending.birthdate,
			REJECT_AFTER_TIME - KEEPALIVE_TIMEOUT - REKEY_TIMEOUT);
	rcu_read_unlock_bh();

	if (unlikely(send)) {
		peer->sent_lastminute_handshake = true;
		wg_packet_send_queued_handshake_initiation(peer, false);
	}
}

static bool decrypt_packet(struct sk_buff *skb, struct noise_keypair *keypair)
{
	struct scatterlist sg[MAX_SKB_FRAGS + 8];
	struct sk_buff *trailer;
	unsigned int offset;
	int num_frags;

	if (unlikely(!keypair))
		return false;

	if (unlikely(!READ_ONCE(keypair->receiving.is_valid) ||
		  wg_birthdate_has_expired(keypair->receiving.birthdate, REJECT_AFTER_TIME) ||
		  READ_ONCE(keypair->receiving_counter.counter) >= REJECT_AFTER_MESSAGES)) {
		WRITE_ONCE(keypair->receiving.is_valid, false);
		return false;
	}

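	/* The counter comes from the still-unauthenticated header; it is only
	 * trusted after decryption succeeds and counter_validate() accepts it.
	 */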
	PACKET_CB(skb)->nonce =
		le64_to_cpu(((struct message_data *)skb->data)->counter);

	/* We ensure that the network header is part of the packet before we
	 * call skb_cow_data, so that there's no chance that data is removed
	 * from the skb; we need that header intact in order to extract the
	 * original endpoint later.
	 */
	offset = -skb_network_offset(skb);
	skb_push(skb, offset);
	num_frags = skb_cow_data(skb, 0, &trailer);
	offset += sizeof(struct message_data);
	skb_pull(skb, offset);
	if (unlikely(num_frags < 0 || num_frags > ARRAY_SIZE(sg)))
		return false;

	sg_init_table(sg, num_frags);
	if (skb_to_sgvec(skb, sg, 0, skb->len) <= 0)
		return false;

	if (!chacha20poly1305_decrypt_sg_inplace(sg, skb->len, NULL, 0,
						 PACKET_CB(skb)->nonce,
						 keypair->receiving.key))
		return false;

	/* Another ugly situation of pushing and pulling the header so as to
	 * keep endpoint information intact.
	 */
	skb_push(skb, offset);
	if (pskb_trim(skb, skb->len - noise_encrypted_len(0)))
		return false;
	skb_pull(skb, offset);

	return true;
}


/* This is RFC6479, a replay detection bitmap algorithm that avoids bitshifts */
static bool counter_validate(struct noise_replay_counter *counter, u64 their_counter)
{
	unsigned long index, index_current, top, i;
	bool ret = false;

	spin_lock_bh(&counter->lock);

	if (unlikely(counter->counter >= REJECT_AFTER_MESSAGES + 1 ||
		     their_counter >= REJECT_AFTER_MESSAGES))
		goto out;

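	/* Pre-increment so that zero is never a valid stored value;
	 * counter->counter therefore holds one past the highest counter
	 * validated so far.
	 */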
	++their_counter;

	if (unlikely((COUNTER_WINDOW_SIZE + their_counter) <
		     counter->counter))
		goto out;

	index = their_counter >> ilog2(BITS_PER_LONG);

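	/* If the counter has advanced, slide the window forward by zeroing
	 * every word between the old top and the new one (at most the whole
	 * bitmap). E.g. with BITS_PER_LONG == 64, counter 130 lands in word
	 * 130 >> 6 == 2, at bit 130 & 63 == 2.
	 */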
	if (likely(their_counter > counter->counter)) {
		index_current = counter->counter >> ilog2(BITS_PER_LONG);
		top = min_t(unsigned long, index - index_current,
			    COUNTER_BITS_TOTAL / BITS_PER_LONG);
		for (i = 1; i <= top; ++i)
			counter->backtrack[(i + index_current) &
				((COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1)] = 0;
		WRITE_ONCE(counter->counter, their_counter);
	}

	index &= (COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1;
	ret = !test_and_set_bit(their_counter & (BITS_PER_LONG - 1),
				&counter->backtrack[index]);

out:
	spin_unlock_bh(&counter->lock);
	return ret;
}

#include "selftest/counter.c"

static void wg_packet_consume_data_done(struct wg_peer *peer,
					struct sk_buff *skb,
					struct endpoint *endpoint)
{
	struct net_device *dev = peer->device->dev;
	unsigned int len, len_before_trim;
	struct wg_peer *routed_peer;

	wg_socket_set_peer_endpoint(peer, endpoint);

	if (unlikely(wg_noise_received_with_keypair(&peer->keypairs,
						    PACKET_CB(skb)->keypair))) {
		wg_timers_handshake_complete(peer);
		wg_packet_send_staged_packets(peer);
	}

	keep_key_fresh(peer);

	wg_timers_any_authenticated_packet_received(peer);
	wg_timers_any_authenticated_packet_traversal(peer);

	/* A packet with length 0 is a keepalive packet */
	if (unlikely(!skb->len)) {
		update_rx_stats(peer, message_data_len(0));
		net_dbg_ratelimited("%s: Receiving keepalive packet from peer %llu (%pISpfsc)\n",
				    dev->name, peer->internal_id,
				    &peer->endpoint.addr);
		goto packet_processed;
	}

	wg_timers_data_received(peer);

	if (unlikely(skb_network_header(skb) < skb->head))
		goto dishonest_packet_size;
	if (unlikely(!(pskb_network_may_pull(skb, sizeof(struct iphdr)) &&
		       (ip_hdr(skb)->version == 4 ||
			(ip_hdr(skb)->version == 6 &&
			 pskb_network_may_pull(skb, sizeof(struct ipv6hdr)))))))
		goto dishonest_packet_type;

	skb->dev = dev;
	/* We've already verified the Poly1305 auth tag, which means this packet
	 * was not modified in transit. We can therefore tell the networking
	 * stack that all checksums of every layer of encapsulation have already
	 * been checked "by the hardware" and that it is therefore unnecessary
	 * to check them again in software.
	 */
	skb->ip_summed = CHECKSUM_UNNECESSARY;
	skb->csum_level = ~0; /* All levels */
	skb->protocol = ip_tunnel_parse_protocol(skb);
	if (skb->protocol == htons(ETH_P_IP)) {
		len = ntohs(ip_hdr(skb)->tot_len);
		if (unlikely(len < sizeof(struct iphdr)))
			goto dishonest_packet_size;
		INET_ECN_decapsulate(skb, PACKET_CB(skb)->ds, ip_hdr(skb)->tos);
	} else if (skb->protocol == htons(ETH_P_IPV6)) {
		len = ntohs(ipv6_hdr(skb)->payload_len) +
		      sizeof(struct ipv6hdr);
		INET_ECN_decapsulate(skb, PACKET_CB(skb)->ds, ipv6_get_dsfield(ipv6_hdr(skb)));
	} else {
		goto dishonest_packet_type;
	}

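	/* The inner IP header declares the plaintext's true length; anything
	 * beyond it is padding added before encryption, so trim it off here
	 * while still accounting for the full length in the rx stats.
	 */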
	if (unlikely(len > skb->len))
		goto dishonest_packet_size;
	len_before_trim = skb->len;
	if (unlikely(pskb_trim(skb, len)))
		goto packet_processed;

	routed_peer = wg_allowedips_lookup_src(&peer->device->peer_allowedips,
					       skb);
	wg_peer_put(routed_peer); /* We don't need the extra reference. */

	if (unlikely(routed_peer != peer))
		goto dishonest_packet_peer;

	napi_gro_receive(&peer->napi, skb);
	update_rx_stats(peer, message_data_len(len_before_trim));
	return;

dishonest_packet_peer:
	net_dbg_skb_ratelimited("%s: Packet has unallowed src IP (%pISc) from peer %llu (%pISpfsc)\n",
				dev->name, skb, peer->internal_id,
				&peer->endpoint.addr);
	DEV_STATS_INC(dev, rx_errors);
	DEV_STATS_INC(dev, rx_frame_errors);
	goto packet_processed;
dishonest_packet_type:
	net_dbg_ratelimited("%s: Packet is neither ipv4 nor ipv6 from peer %llu (%pISpfsc)\n",
			    dev->name, peer->internal_id, &peer->endpoint.addr);
	DEV_STATS_INC(dev, rx_errors);
	DEV_STATS_INC(dev, rx_frame_errors);
	goto packet_processed;
dishonest_packet_size:
	net_dbg_ratelimited("%s: Packet has incorrect size from peer %llu (%pISpfsc)\n",
			    dev->name, peer->internal_id, &peer->endpoint.addr);
	DEV_STATS_INC(dev, rx_errors);
	DEV_STATS_INC(dev, rx_length_errors);
	goto packet_processed;
packet_processed:
	dev_kfree_skb(skb);
}

int wg_packet_rx_poll(struct napi_struct *napi, int budget)
{
	struct wg_peer *peer = container_of(napi, struct wg_peer, napi);
	struct noise_keypair *keypair;
	struct endpoint endpoint;
	enum packet_state state;
	struct sk_buff *skb;
	int work_done = 0;
	bool free;

	if (unlikely(budget <= 0))
		return 0;

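	/* The acquire read of state pairs with the release store made when
	 * the decrypt worker publishes PACKET_STATE_CRYPTED or
	 * PACKET_STATE_DEAD, so the decrypted contents are visible here.
	 */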
	while ((skb = wg_prev_queue_peek(&peer->rx_queue)) != NULL &&
	       (state = atomic_read_acquire(&PACKET_CB(skb)->state)) !=
		       PACKET_STATE_UNCRYPTED) {
		wg_prev_queue_drop_peeked(&peer->rx_queue);
		keypair = PACKET_CB(skb)->keypair;
		free = true;

		if (unlikely(state != PACKET_STATE_CRYPTED))
			goto next;

		if (unlikely(!counter_validate(&keypair->receiving_counter,
					       PACKET_CB(skb)->nonce))) {
			net_dbg_ratelimited("%s: Packet has invalid nonce %llu (max %llu)\n",
					    peer->device->dev->name,
					    PACKET_CB(skb)->nonce,
					    READ_ONCE(keypair->receiving_counter.counter));
			goto next;
		}

		if (unlikely(wg_socket_endpoint_from_skb(&endpoint, skb)))
			goto next;

		wg_reset_packet(skb, false);
		wg_packet_consume_data_done(peer, skb, &endpoint);
		free = false;

next:
		wg_noise_keypair_put(keypair, false);
		wg_peer_put(peer);
		if (unlikely(free))
			dev_kfree_skb(skb);

		if (++work_done >= budget)
			break;
	}

	if (work_done < budget)
		napi_complete_done(napi, work_done);

	return work_done;
}

void wg_packet_decrypt_worker(struct work_struct *work)
{
	struct crypt_queue *queue = container_of(work, struct multicore_worker,
						 work)->ptr;
	struct sk_buff *skb;

	while ((skb = ptr_ring_consume_bh(&queue->ring)) != NULL) {
		enum packet_state state =
			likely(decrypt_packet(skb, PACKET_CB(skb)->keypair)) ?
				PACKET_STATE_CRYPTED : PACKET_STATE_DEAD;
		wg_queue_enqueue_per_peer_rx(skb, state);
		if (need_resched())
			cond_resched();
	}
}

static void wg_packet_consume_data(struct wg_device *wg, struct sk_buff *skb)
{
	__le32 idx = ((struct message_data *)skb->data)->key_idx;
	struct wg_peer *peer = NULL;
	int ret;

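	/* Look up the keypair by its receiver index; the lookup also takes a
	 * reference to the owning peer, which is kept while the packet is in
	 * flight and dropped on the error paths below.
	 */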
	rcu_read_lock_bh();
	PACKET_CB(skb)->keypair =
		(struct noise_keypair *)wg_index_hashtable_lookup(
			wg->index_hashtable, INDEX_HASHTABLE_KEYPAIR, idx,
			&peer);
	if (unlikely(!wg_noise_keypair_get(PACKET_CB(skb)->keypair)))
		goto err_keypair;

	if (unlikely(READ_ONCE(peer->is_dead)))
		goto err;

	ret = wg_queue_enqueue_per_device_and_peer(&wg->decrypt_queue, &peer->rx_queue, skb,
						   wg->packet_crypt_wq);
	if (unlikely(ret == -EPIPE))
		wg_queue_enqueue_per_peer_rx(skb, PACKET_STATE_DEAD);
	if (likely(!ret || ret == -EPIPE)) {
		rcu_read_unlock_bh();
		return;
	}
err:
	wg_noise_keypair_put(PACKET_CB(skb)->keypair, false);
err_keypair:
	rcu_read_unlock_bh();
	wg_peer_put(peer);
	dev_kfree_skb(skb);
}

void wg_packet_receive(struct wg_device *wg, struct sk_buff *skb)
{
	if (unlikely(prepare_skb_header(skb, wg) < 0))
		goto err;
	switch (SKB_TYPE_LE32(skb)) {
	case cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION):
	case cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE):
	case cpu_to_le32(MESSAGE_HANDSHAKE_COOKIE): {
		int cpu, ret = -EBUSY;

		if (unlikely(!rng_is_initialized()))
			goto drop;
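		/* Once the handshake queue is more than half full, only
		 * opportunistically try to take the producer lock rather than
		 * spinning for it, so a flood of handshakes can't stall this
		 * softirq path.
		 */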
		if (atomic_read(&wg->handshake_queue_len) > MAX_QUEUED_INCOMING_HANDSHAKES / 2) {
			if (spin_trylock_bh(&wg->handshake_queue.ring.producer_lock)) {
				ret = __ptr_ring_produce(&wg->handshake_queue.ring, skb);
				spin_unlock_bh(&wg->handshake_queue.ring.producer_lock);
			}
		} else
			ret = ptr_ring_produce_bh(&wg->handshake_queue.ring, skb);
		if (ret) {
	drop:
			net_dbg_skb_ratelimited("%s: Dropping handshake packet from %pISpfsc\n",
						wg->dev->name, skb);
			goto err;
		}
		atomic_inc(&wg->handshake_queue_len);
		cpu = wg_cpumask_next_online(&wg->handshake_queue.last_cpu);
		/* Queues up a call to wg_packet_handshake_receive_worker(): */
		queue_work_on(cpu, wg->handshake_receive_wq,
			      &per_cpu_ptr(wg->handshake_queue.worker, cpu)->work);
		break;
	}
	case cpu_to_le32(MESSAGE_DATA):
		PACKET_CB(skb)->ds = ip_tunnel_get_dsfield(ip_hdr(skb), skb);
		wg_packet_consume_data(wg, skb);
		break;
	default:
		WARN(1, "Non-exhaustive parsing of packet header led to unknown packet type!\n");
		goto err;
	}
	return;

err:
	dev_kfree_skb(skb);
}