1// SPDX-License-Identifier: GPL-2.0
2/*
 * Selftest that verifies that incoming ICMPs are ignored,
 * the TCP connection stays alive, no hard or soft errors get reported
 * to the userspace and the counter for ignored ICMPs is updated.
6 *
7 * RFC5925, 7.8:
8 * >> A TCP-AO implementation MUST default to ignore incoming ICMPv4
9 * messages of Type 3 (destination unreachable), Codes 2-4 (protocol
 * unreachable, port unreachable, and fragmentation needed -- 'hard
 * errors'), and ICMPv6 Type 1 (destination unreachable), Code 1
12 * (administratively prohibited) and Code 4 (port unreachable) intended
13 * for connections in synchronized states (ESTABLISHED, FIN-WAIT-1, FIN-
14 * WAIT-2, CLOSE-WAIT, CLOSING, LAST-ACK, TIME-WAIT) that match MKTs.
15 *
16 * Author: Dmitry Safonov <dima@arista.com>
17 */
18#include <inttypes.h>
19#include <linux/icmp.h>
20#include <linux/icmpv6.h>
21#include <linux/ipv6.h>
22#include <netinet/in.h>
23#include <netinet/ip.h>
24#include <sys/socket.h>
25#include "aolib.h"
26#include "../../../../include/linux/compiler.h"
27
/* Packet size / count handed to test_client_verify() on every client round. */
const size_t packets_nr = 20;
const size_t packet_size = 100;
/* netstat counter of incoming ICMPs dropped by TCP-AO. */
const char *tcpao_icmps = "TCPAODroppedIcmps";

/*
 * Address-family specific knobs: the netstat counter of received
 * "destination unreachable" ICMPs, and the socket level/option used to
 * enable extended error reporting on the server socket.
 */
#if defined(IPV6_TEST)
const char *dst_unreach = "Icmp6InDestUnreachs";
const int sk_ip_level = SOL_IPV6;
const int sk_recverr = IPV6_RECVERR;
#else /* IPv4 */
const char *dst_unreach = "InDestUnreachs";
const int sk_ip_level = SOL_IP;
const int sk_recverr = IP_RECVERR;
#endif

/*
 * When built with TEST_ICMPS_ACCEPT the socket is told to accept ICMPs, so
 * the server is expected to fail with a hard error: the meaning of the
 * ICMP-related pass/fail reporters is swapped accordingly.
 */
#if defined(TEST_ICMPS_ACCEPT)
# define test_icmps_fail	test_ok
# define test_icmps_ok		test_fail
#else
# define test_icmps_fail	test_fail
# define test_icmps_ok		test_ok
#endif
50
/*
 * serve_interfered - run the server side of the test on the accepted
 * socket @sk and report the results.
 *
 * Takes a netstat snapshot and the per-socket TCP-AO counters, serves
 * test_quota bytes with test_server_run() (while the client is forging
 * ICMP errors), snapshots both again and then reports, in this order:
 *  1. the family's "destination unreachable" netstat counter grew, i.e. the
 *     forged ICMPs reached this host at all -- if not, the function bails out
 *     since nothing else would be meaningful;
 *  2. the server survived the traffic (inverted under TEST_ICMPS_ACCEPT via
 *     the test_icmps_ok/test_icmps_fail macros);
 *  3. the TCPAODroppedIcmps netstat counter exists;
 *  4. the per-socket TCP-AO counters compare as expected for the
 *     TEST_CNT_* mask (exact semantics live in test_tcp_ao_counters_cmp());
 *  5. the TCPAODroppedIcmps counter grew (again inverted under
 *     TEST_ICMPS_ACCEPT).
 */
static void serve_interfered(int sk)
{
	/* Bytes of traffic the server is asked to handle. */
	ssize_t test_quota = packet_size * packets_nr * 10;
	uint64_t dest_unreach_a, dest_unreach_b;
	uint64_t icmp_ignored_a, icmp_ignored_b;
	struct tcp_ao_counters ao_cnt1, ao_cnt2;
	bool counter_not_found;
	struct netstat *ns_after, *ns_before;
	ssize_t bytes;

	/* "Before" snapshot: netstat counters and per-socket AO counters. */
	ns_before = netstat_read();
	dest_unreach_a = netstat_get(ns_before, dst_unreach, NULL);
	icmp_ignored_a = netstat_get(ns_before, tcpao_icmps, NULL);
	if (test_get_tcp_ao_counters(sk, &ao_cnt1))
		test_error("test_get_tcp_ao_counters()");
	/* Negative on failure: -bytes is decoded with strerrordesc_np() below. */
	bytes = test_server_run(sk, test_quota, 0);
	/* "After" snapshot; only this read records whether the AO counter exists. */
	ns_after = netstat_read();
	netstat_print_diff(ns_before, ns_after);
	dest_unreach_b = netstat_get(ns_after, dst_unreach, NULL);
	icmp_ignored_b = netstat_get(ns_after, tcpao_icmps,
					&counter_not_found);
	if (test_get_tcp_ao_counters(sk, &ao_cnt2))
		test_error("test_get_tcp_ao_counters()");

	netstat_free(ns_before);
	netstat_free(ns_after);

	/* No ICMPs delivered at all: the remaining checks would be vacuous. */
	if (dest_unreach_a >= dest_unreach_b) {
		test_fail("%s counter didn't change: %" PRIu64 " >= %" PRIu64,
				dst_unreach, dest_unreach_a, dest_unreach_b);
		return;
	}
	test_ok("%s delivered %" PRIu64,
		dst_unreach, dest_unreach_b - dest_unreach_a);
	if (bytes < 0)
		test_icmps_fail("Server failed with %zd: %s", bytes, strerrordesc_np(-bytes));
	else
		test_icmps_ok("Server survived %zd bytes of traffic", test_quota);
	if (counter_not_found) {
		test_fail("Not found %s counter", tcpao_icmps);
		return;
	}
	/* When ICMPs are accepted, the AO dropped-ICMP counter is not in the mask. */
#ifdef TEST_ICMPS_ACCEPT
	test_tcp_ao_counters_cmp(NULL, &ao_cnt1, &ao_cnt2, TEST_CNT_GOOD);
#else
	test_tcp_ao_counters_cmp(NULL, &ao_cnt1, &ao_cnt2, TEST_CNT_GOOD | TEST_CNT_AO_DROPPED_ICMP);
#endif
	if (icmp_ignored_a >= icmp_ignored_b) {
		test_icmps_fail("%s counter didn't change: %" PRIu64 " >= %" PRIu64,
				tcpao_icmps, icmp_ignored_a, icmp_ignored_b);
		return;
	}
	test_icmps_ok("ICMPs ignored %" PRIu64, icmp_ignored_b - icmp_ignored_a);
}
105
106static void *server_fn(void *arg)
107{
108	int val, sk, lsk;
109	bool accept_icmps = false;
110
111	lsk = test_listen_socket(this_ip_addr, test_server_port, 1);
112
113#ifdef TEST_ICMPS_ACCEPT
114	accept_icmps = true;
115#endif
116
117	if (test_set_ao_flags(lsk, false, accept_icmps))
118		test_error("setsockopt(TCP_AO_INFO)");
119
120	if (test_add_key(lsk, DEFAULT_TEST_PASSWORD, this_ip_dest, -1, 100, 100))
121		test_error("setsockopt(TCP_AO_ADD_KEY)");
122	synchronize_threads();
123
124	if (test_wait_fd(lsk, TEST_TIMEOUT_SEC, 0))
125		test_error("test_wait_fd()");
126
127	sk = accept(lsk, NULL, NULL);
128	if (sk < 0)
129		test_error("accept()");
130
131	/* Fail on hard ip errors, such as dest unreachable (RFC1122) */
132	val = 1;
133	if (setsockopt(sk, sk_ip_level, sk_recverr, &val, sizeof(val)))
134		test_error("setsockopt()");
135
136	synchronize_threads();
137
138	serve_interfered(sk);
139	return NULL;
140}
141
/*
 * Client-side tallies of data packets and forged ICMPs sent so far.
 * NOTE(review): neither counter is read anywhere in the code visible here.
 */
static size_t packets_sent, icmps_sent;
144
/*
 * checksum4_nofold - add the 16-bit words of @data to the accumulator @sum
 * (RFC 1071 Internet checksum) without folding carries.
 * @data: buffer to sum; any object type and any alignment are accepted
 * @len:  buffer length in bytes; an odd trailing byte is summed as if the
 *        buffer were followed by one zero byte
 * @sum:  initial accumulator value
 *
 * Returns the updated, unfolded 32-bit accumulator.
 *
 * Bytes are read through an unsigned (uint8_t) pointer and reassembled in
 * host byte order through a union. This avoids the strict-aliasing
 * violation and possible misaligned access of reading arbitrary objects
 * through a uint16_t pointer. It also keeps a trailing byte >= 0x80 from
 * being sign-extended, and places that byte correctly on big-endian hosts.
 */
static uint32_t checksum4_nofold(void *data, size_t len, uint32_t sum)
{
	const uint8_t *bytes = data;
	union {
		uint8_t b[2];
		uint16_t w;
	} word;
	size_t i;

	for (i = 0; i + 1 < len; i += 2) {
		word.b[0] = bytes[i];
		word.b[1] = bytes[i + 1];
		sum += word.w;
	}
	if (len & 1) {
		/* Zero-pad the odd byte, exactly as RFC 1071 prescribes. */
		word.b[0] = bytes[len - 1];
		word.b[1] = 0;
		sum += word.w;
	}
	return sum;
}
156
/*
 * checksum4_fold - finish an Internet checksum over @data/@len, starting
 * from the partial sum @sum.
 *
 * Folds the 32-bit accumulator down to 16 bits, adding every carry back in,
 * and returns its one's complement (the value stored in a header).
 */
static uint16_t checksum4_fold(void *data, size_t len, uint32_t sum)
{
	uint32_t acc = checksum4_nofold(data, len, sum);

	while (acc >> 16)
		acc = (acc >> 16) + (acc & 0xFFFF);
	return (uint16_t)~acc;
}
164
/*
 * set_ip4hdr - fill in a minimal IPv4 header (ihl = 5, no options).
 * @iph:        header to fill. id and frag_off are not written; callers in
 *              this file pass a zero-initialised header.
 * @packet_len: value for tot_len, in host byte order
 * @proto:      protocol number of the payload
 * @src:        source address
 * @dst:        destination address
 *
 * The header checksum is computed last. It covers the full header, which is
 * ihl 32-bit words (ihl << 2 bytes). The check field is zeroed before summing.
 */
static void set_ip4hdr(struct iphdr *iph, size_t packet_len, int proto,
		struct sockaddr_in *src, struct sockaddr_in *dst)
{
	iph->version	= 4;
	iph->ihl	= 5;
	iph->tos	= 0;
	iph->tot_len	= htons(packet_len);
	iph->ttl	= 2;
	iph->protocol	= proto;
	iph->saddr	= src->sin_addr.s_addr;
	iph->daddr	= dst->sin_addr.s_addr;
	/*
	 * ihl counts 32-bit words: summing ihl << 1 bytes would cover only
	 * the first 10 bytes and skip the addresses.
	 */
	iph->check	= 0;
	iph->check	= checksum4_fold((void *)iph, iph->ihl << 2, 0);
}
178
179static void icmp_interfere4(uint8_t type, uint8_t code, uint32_t rcv_nxt,
180		struct sockaddr_in *src, struct sockaddr_in *dst)
181{
182	int sk = socket(AF_INET, SOCK_RAW, IPPROTO_RAW);
183	struct {
184		struct iphdr iph;
185		struct icmphdr icmph;
186		struct iphdr iphe;
187		struct {
188			uint16_t sport;
189			uint16_t dport;
190			uint32_t seq;
191		} tcph;
192	} packet = {};
193	size_t packet_len;
194	ssize_t bytes;
195
196	if (sk < 0)
197		test_error("socket(AF_INET, SOCK_RAW, IPPROTO_RAW)");
198
199	packet_len = sizeof(packet);
200	set_ip4hdr(&packet.iph, packet_len, IPPROTO_ICMP, src, dst);
201
202	packet.icmph.type = type;
203	packet.icmph.code = code;
204	if (code == ICMP_FRAG_NEEDED) {
205		randomize_buffer(&packet.icmph.un.frag.mtu,
206				sizeof(packet.icmph.un.frag.mtu));
207	}
208
209	packet_len = sizeof(packet.iphe) + sizeof(packet.tcph);
210	set_ip4hdr(&packet.iphe, packet_len, IPPROTO_TCP, dst, src);
211
212	packet.tcph.sport = dst->sin_port;
213	packet.tcph.dport = src->sin_port;
214	packet.tcph.seq = htonl(rcv_nxt);
215
216	packet_len = sizeof(packet) - sizeof(packet.iph);
217	packet.icmph.checksum = checksum4_fold((void *)&packet.icmph,
218						packet_len, 0);
219
220	bytes = sendto(sk, &packet, sizeof(packet), 0,
221		       (struct sockaddr *)dst, sizeof(*dst));
222	if (bytes != sizeof(packet))
223		test_error("send(): %zd", bytes);
224	icmps_sent++;
225
226	close(sk);
227}
228
/*
 * set_ip6hdr - fill in the fixed IPv6 header fields.
 * @iph:        header to fill. priority and flow_lbl are not written;
 *              callers in this file pass a zero-initialised header.
 * @packet_len: payload length in host byte order (excluding this header)
 * @proto:      next-header value
 * @src:        source address
 * @dst:        destination address
 */
static void set_ip6hdr(struct ipv6hdr *iph, size_t packet_len, int proto,
		struct sockaddr_in6 *src, struct sockaddr_in6 *dst)
{
	const struct in6_addr from = src->sin6_addr;
	const struct in6_addr to = dst->sin6_addr;

	iph->saddr		= from;
	iph->daddr		= to;
	iph->hop_limit		= 2;
	iph->nexthdr		= proto;
	iph->payload_len	= htons(packet_len);
	iph->version		= 6;
}
239
/*
 * csum_fold - fold a 32-bit one's complement sum to 16 bits and return its
 * complement.
 *
 * Two folding rounds always suffice. The first leaves at most 0x1fffe; the
 * second folds that into at most 0xffff.
 */
static inline uint16_t csum_fold(uint32_t csum)
{
	csum = (csum >> 16) + (csum & 0xffff);
	csum = (csum >> 16) + (csum & 0xffff);
	return (uint16_t)~csum;
}
248
/*
 * csum_add - one's complement addition of @addend to @csum.
 *
 * A carry out of bit 31 shows up as unsigned wrap-around (the result is
 * smaller than the addend). It is added back into bit 0 (end-around carry).
 */
static inline uint32_t csum_add(uint32_t csum, uint32_t addend)
{
	uint32_t total = csum + addend;

	if (total < addend)
		total++;
	return total;
}
256
257noinline uint32_t checksum6_nofold(void *data, size_t len, uint32_t sum)
258{
259	uint16_t *words = data;
260	size_t i;
261
262	for (i = 0; i < len / sizeof(uint16_t); i++)
263		sum = csum_add(sum, words[i]);
264	if (len & 1)
265		sum = csum_add(sum, ((char *)data)[len - 1]);
266	return sum;
267}
268
269noinline uint16_t icmp6_checksum(struct sockaddr_in6 *src,
270				 struct sockaddr_in6 *dst,
271				 void *ptr, size_t len, uint8_t proto)
272{
273	struct {
274		struct in6_addr saddr;
275		struct in6_addr daddr;
276		uint32_t payload_len;
277		uint8_t zero[3];
278		uint8_t nexthdr;
279	} pseudo_header = {};
280	uint32_t sum;
281
282	pseudo_header.saddr		= src->sin6_addr;
283	pseudo_header.daddr		= dst->sin6_addr;
284	pseudo_header.payload_len	= htonl(len);
285	pseudo_header.nexthdr		= proto;
286
287	sum = checksum6_nofold(&pseudo_header, sizeof(pseudo_header), 0);
288	sum = checksum6_nofold(ptr, len, sum);
289
290	return csum_fold(sum);
291}
292
293static void icmp6_interfere(int type, int code, uint32_t rcv_nxt,
294		struct sockaddr_in6 *src, struct sockaddr_in6 *dst)
295{
296	int sk = socket(AF_INET6, SOCK_RAW, IPPROTO_RAW);
297	struct sockaddr_in6 dst_raw = *dst;
298	struct {
299		struct ipv6hdr iph;
300		struct icmp6hdr icmph;
301		struct ipv6hdr iphe;
302		struct {
303			uint16_t sport;
304			uint16_t dport;
305			uint32_t seq;
306		} tcph;
307	} packet = {};
308	size_t packet_len;
309	ssize_t bytes;
310
311
312	if (sk < 0)
313		test_error("socket(AF_INET6, SOCK_RAW, IPPROTO_RAW)");
314
315	packet_len = sizeof(packet) - sizeof(packet.iph);
316	set_ip6hdr(&packet.iph, packet_len, IPPROTO_ICMPV6, src, dst);
317
318	packet.icmph.icmp6_type = type;
319	packet.icmph.icmp6_code = code;
320
321	packet_len = sizeof(packet.iphe) + sizeof(packet.tcph);
322	set_ip6hdr(&packet.iphe, packet_len, IPPROTO_TCP, dst, src);
323
324	packet.tcph.sport = dst->sin6_port;
325	packet.tcph.dport = src->sin6_port;
326	packet.tcph.seq = htonl(rcv_nxt);
327
328	packet_len = sizeof(packet) - sizeof(packet.iph);
329
330	packet.icmph.icmp6_cksum = icmp6_checksum(src, dst,
331			(void *)&packet.icmph, packet_len, IPPROTO_ICMPV6);
332
333	dst_raw.sin6_port = htons(IPPROTO_RAW);
334	bytes = sendto(sk, &packet, sizeof(packet), 0,
335		       (struct sockaddr *)&dst_raw, sizeof(dst_raw));
336	if (bytes != sizeof(packet))
337		test_error("send(): %zd", bytes);
338	icmps_sent++;
339
340	close(sk);
341}
342
343static uint32_t get_rcv_nxt(int sk)
344{
345	int val = TCP_REPAIR_ON;
346	uint32_t ret;
347	socklen_t sz = sizeof(ret);
348
349	if (setsockopt(sk, SOL_TCP, TCP_REPAIR, &val, sizeof(val)))
350		test_error("setsockopt(TCP_REPAIR)");
351	val = TCP_RECV_QUEUE;
352	if (setsockopt(sk, SOL_TCP, TCP_REPAIR_QUEUE, &val, sizeof(val)))
353		test_error("setsockopt(TCP_REPAIR_QUEUE)");
354	if (getsockopt(sk, SOL_TCP, TCP_QUEUE_SEQ, &ret, &sz))
355		test_error("getsockopt(TCP_QUEUE_SEQ)");
356	val = TCP_REPAIR_OFF_NO_WP;
357	if (setsockopt(sk, SOL_TCP, TCP_REPAIR, &val, sizeof(val)))
358		test_error("setsockopt(TCP_REPAIR)");
359	return ret;
360}
361
/*
 * icmp_interfere - send @nr rounds of forged "destination unreachable"
 * errors from @src to @dst.
 * @nr:      number of rounds
 * @rcv_nxt: sequence number placed in each quoted TCP header
 * @src:     local address (struct sockaddr_in or struct sockaddr_in6)
 * @dst:     peer address; must be of the same family as @src
 *
 * Each round sends the codes RFC 5925 7.8 says TCP-AO must ignore:
 *  - IPv4: protocol unreachable, port unreachable, fragmentation needed;
 *  - IPv6: administratively prohibited, port unreachable.
 *
 * icmps_sent is incremented by icmp_interfere4()/icmp6_interfere() for each
 * packet they send. It is therefore not bumped again here; previously every
 * ICMP was counted twice.
 */
static void icmp_interfere(const size_t nr, uint32_t rcv_nxt, void *src, void *dst)
{
	struct sockaddr_in *saddr4 = src;
	struct sockaddr_in *daddr4 = dst;
	struct sockaddr_in6 *saddr6 = src;
	struct sockaddr_in6 *daddr6 = dst;
	size_t i;

	if (saddr4->sin_family != daddr4->sin_family)
		test_error("Different address families");

	for (i = 0; i < nr; i++) {
		switch (saddr4->sin_family) {
		case AF_INET:
			icmp_interfere4(ICMP_DEST_UNREACH, ICMP_PROT_UNREACH,
					rcv_nxt, saddr4, daddr4);
			icmp_interfere4(ICMP_DEST_UNREACH, ICMP_PORT_UNREACH,
					rcv_nxt, saddr4, daddr4);
			icmp_interfere4(ICMP_DEST_UNREACH, ICMP_FRAG_NEEDED,
					rcv_nxt, saddr4, daddr4);
			break;
		case AF_INET6:
			icmp6_interfere(ICMPV6_DEST_UNREACH,
					ICMPV6_ADM_PROHIBITED,
					rcv_nxt, saddr6, daddr6);
			icmp6_interfere(ICMPV6_DEST_UNREACH,
					ICMPV6_PORT_UNREACH,
					rcv_nxt, saddr6, daddr6);
			break;
		default:
			test_error("Not ip address family");
			break;
		}
	}
}
395
396static void send_interfered(int sk)
397{
398	const unsigned int timeout = TEST_TIMEOUT_SEC;
399	struct sockaddr_in6 src, dst;
400	socklen_t addr_sz;
401
402	addr_sz = sizeof(src);
403	if (getsockname(sk, &src, &addr_sz))
404		test_error("getsockname()");
405	addr_sz = sizeof(dst);
406	if (getpeername(sk, &dst, &addr_sz))
407		test_error("getpeername()");
408
409	while (1) {
410		uint32_t rcv_nxt;
411
412		if (test_client_verify(sk, packet_size, packets_nr, timeout)) {
413			test_fail("client: connection is broken");
414			return;
415		}
416		packets_sent += packets_nr;
417		rcv_nxt = get_rcv_nxt(sk);
418		icmp_interfere(packets_nr, rcv_nxt, (void *)&src, (void *)&dst);
419	}
420}
421
422static void *client_fn(void *arg)
423{
424	int sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
425
426	if (sk < 0)
427		test_error("socket()");
428
429	if (test_add_key(sk, DEFAULT_TEST_PASSWORD, this_ip_dest, -1, 100, 100))
430		test_error("setsockopt(TCP_AO_ADD_KEY)");
431
432	synchronize_threads();
433	if (test_connect_socket(sk, this_ip_dest, test_server_port) <= 0)
434		test_error("failed to connect()");
435	synchronize_threads();
436
437	send_interfered(sk);
438
439	/* Not expecting client to quit */
440	test_fail("client disconnected");
441
442	return NULL;
443}
444
445int main(int argc, char *argv[])
446{
447	test_init(3, server_fn, client_fn);
448	return 0;
449}
450