1// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
2// Copyright (c) 2020 Cloudflare
3
4#define _GNU_SOURCE
5
6#include <arpa/inet.h>
7#include <string.h>
8
9#include <linux/pkt_cls.h>
10#include <netinet/tcp.h>
11
12#include <test_progs.h>
13
14#include "progs/test_cls_redirect.h"
15#include "test_cls_redirect.skel.h"
16#include "test_cls_redirect_dynptr.skel.h"
17#include "test_cls_redirect_subprogs.skel.h"
18
/* Outer destination address/port of the encapsulated test packets (host
 * byte order here; converted with htonl()/htons() where used). The same
 * values are written into each skeleton's ENCAPSULATION_IP and
 * ENCAPSULATION_PORT rodata before load, presumably so the BPF program
 * treats these packets as addressed to itself.
 */
#define ENCAP_IP INADDR_LOOPBACK
#define ENCAP_PORT (1234)

/* NOTE(review): not referenced directly in this file; presumably required by
 * the CHECK() macro from test_progs.h used in the inlined/subprogs tests.
 * Do not remove while CHECK() is still used here.
 */
static int duration = 0;
23
/* A transport port plus an IPv4 or IPv6 address, copied verbatim from a
 * sockaddr_in/sockaddr_in6 by fill_addr_port() (so both are in network byte
 * order). Which union member is valid depends on the owning tuple's family.
 */
struct addr_port {
	in_port_t port;
	union {
		struct in_addr in_addr;
		struct in6_addr in6_addr;
	};
};

/* Address family plus source/destination endpoints of a simulated packet.
 * set_up_conn() fills it from the client socket's point of view: dst is the
 * client's local address, src is its peer (the server).
 */
struct tuple {
	int family;
	struct addr_port src;
	struct addr_port dst;
};
37
/*
 * Create a socket of @type bound to @addr; stream sockets are also put into
 * the listening state. Returns the socket fd, or -1 after flagging the
 * failing step via CHECK_FAIL (the socket is closed on failure).
 */
static int start_server(const struct sockaddr *addr, socklen_t len, int type)
{
	int sock = socket(addr->sa_family, type, 0);
	bool ok;

	if (CHECK_FAIL(sock == -1))
		return -1;

	ok = !CHECK_FAIL(bind(sock, addr, len) == -1);
	/* Only connection-oriented sockets need listen(). */
	if (ok && type == SOCK_STREAM)
		ok = !CHECK_FAIL(listen(sock, 128) == -1);

	if (!ok) {
		close(sock);
		return -1;
	}

	return sock;
}
54
/*
 * Create a socket of @type and connect it to @addr. Returns the connected
 * fd, or -1 after flagging the failing step via CHECK_FAIL (the socket is
 * closed if connect() fails).
 */
static int connect_to_server(const struct sockaddr *addr, socklen_t len,
			     int type)
{
	int sock = socket(addr->sa_family, type, 0);

	if (CHECK_FAIL(sock == -1))
		return -1;

	if (CHECK_FAIL(connect(sock, addr, len))) {
		close(sock);
		return -1;
	}

	return sock;
}
70
71static bool fill_addr_port(const struct sockaddr *sa, struct addr_port *ap)
72{
73	const struct sockaddr_in6 *in6;
74	const struct sockaddr_in *in;
75
76	switch (sa->sa_family) {
77	case AF_INET:
78		in = (const struct sockaddr_in *)sa;
79		ap->in_addr = in->sin_addr;
80		ap->port = in->sin_port;
81		return true;
82
83	case AF_INET6:
84		in6 = (const struct sockaddr_in6 *)sa;
85		ap->in6_addr = in6->sin6_addr;
86		ap->port = in6->sin6_port;
87		return true;
88
89	default:
90		return false;
91	}
92}
93
94static bool set_up_conn(const struct sockaddr *addr, socklen_t len, int type,
95			int *server, int *conn, struct tuple *tuple)
96{
97	struct sockaddr_storage ss;
98	socklen_t slen = sizeof(ss);
99	struct sockaddr *sa = (struct sockaddr *)&ss;
100
101	*server = start_server(addr, len, type);
102	if (*server < 0)
103		return false;
104
105	if (CHECK_FAIL(getsockname(*server, sa, &slen)))
106		goto close_server;
107
108	*conn = connect_to_server(sa, slen, type);
109	if (*conn < 0)
110		goto close_server;
111
112	/* We want to simulate packets arriving at conn, so we have to
113	 * swap src and dst.
114	 */
115	slen = sizeof(ss);
116	if (CHECK_FAIL(getsockname(*conn, sa, &slen)))
117		goto close_conn;
118
119	if (CHECK_FAIL(!fill_addr_port(sa, &tuple->dst)))
120		goto close_conn;
121
122	slen = sizeof(ss);
123	if (CHECK_FAIL(getpeername(*conn, sa, &slen)))
124		goto close_conn;
125
126	if (CHECK_FAIL(!fill_addr_port(sa, &tuple->src)))
127		goto close_conn;
128
129	tuple->family = ss.ss_family;
130	return true;
131
132close_conn:
133	close(*conn);
134	*conn = -1;
135close_server:
136	close(*server);
137	*server = -1;
138	return false;
139}
140
141static socklen_t prepare_addr(struct sockaddr_storage *addr, int family)
142{
143	struct sockaddr_in *addr4;
144	struct sockaddr_in6 *addr6;
145
146	switch (family) {
147	case AF_INET:
148		addr4 = (struct sockaddr_in *)addr;
149		memset(addr4, 0, sizeof(*addr4));
150		addr4->sin_family = family;
151		addr4->sin_addr.s_addr = htonl(INADDR_LOOPBACK);
152		return sizeof(*addr4);
153	case AF_INET6:
154		addr6 = (struct sockaddr_in6 *)addr;
155		memset(addr6, 0, sizeof(*addr6));
156		addr6->sin6_family = family;
157		addr6->sin6_addr = in6addr_loopback;
158		return sizeof(*addr6);
159	default:
160		fprintf(stderr, "Invalid family %d", family);
161		return 0;
162	}
163}
164
165static bool was_decapsulated(struct bpf_test_run_opts *tattr)
166{
167	return tattr->data_size_out < tattr->data_size_in;
168}
169
/* Transport protocol of a test case. The values also index the first
 * dimension of the servers/conns/tuples arrays in test_cls_redirect_common(),
 * which are sized by __NR_KIND, so UDP/TCP must stay 0-based and contiguous.
 */
enum type {
	UDP,
	TCP,
	__NR_KIND,
};

/* Number of next-hop entries placed after the UniGUE header by
 * build_input(): NO_HOPS -> 0, ONE_HOP -> 1.
 */
enum hops {
	NO_HOPS,
	ONE_HOP,
};

/* TCP flag set on the inner TCP header (NONE for UDP cases). */
enum flags {
	NONE,
	SYN,
	ACK,
};

/* KNOWN_CONN uses the real client tuple; UNKNOWN_CONN decrements the inner
 * source port in build_input() so it no longer matches that tuple.
 */
enum conn {
	KNOWN_CONN,
	UNKNOWN_CONN,
};

/* Expected outcome: ACCEPT requires the output to be shorter than the input
 * (decapsulated), FORWARD requires that it is not.
 */
enum result {
	ACCEPT,
	FORWARD,
};

/* One test case. Field order matches the positional initialisers in tests[]. */
struct test_cfg {
	enum type type;
	enum result result;
	enum conn conn;
	enum hops hops;
	enum flags flags;
};
204
205static int test_str(void *buf, size_t len, const struct test_cfg *test,
206		    int family)
207{
208	const char *family_str, *type, *conn, *hops, *result, *flags;
209
210	family_str = "IPv4";
211	if (family == AF_INET6)
212		family_str = "IPv6";
213
214	type = "TCP";
215	if (test->type == UDP)
216		type = "UDP";
217
218	conn = "known";
219	if (test->conn == UNKNOWN_CONN)
220		conn = "unknown";
221
222	hops = "no hops";
223	if (test->hops == ONE_HOP)
224		hops = "one hop";
225
226	result = "accept";
227	if (test->result == FORWARD)
228		result = "forward";
229
230	flags = "none";
231	if (test->flags == SYN)
232		flags = "SYN";
233	else if (test->flags == ACK)
234		flags = "ACK";
235
236	return snprintf(buf, len, "%s %s %s %s (%s, flags: %s)", family_str,
237			type, result, conn, hops, flags);
238}
239
240static struct test_cfg tests[] = {
241	{ TCP, ACCEPT, UNKNOWN_CONN, NO_HOPS, SYN },
242	{ TCP, ACCEPT, UNKNOWN_CONN, NO_HOPS, ACK },
243	{ TCP, FORWARD, UNKNOWN_CONN, ONE_HOP, ACK },
244	{ TCP, ACCEPT, KNOWN_CONN, ONE_HOP, ACK },
245	{ UDP, ACCEPT, UNKNOWN_CONN, NO_HOPS, NONE },
246	{ UDP, FORWARD, UNKNOWN_CONN, ONE_HOP, NONE },
247	{ UDP, ACCEPT, KNOWN_CONN, ONE_HOP, NONE },
248};
249
250static void encap_init(encap_headers_t *encap, uint8_t hop_count, uint8_t proto)
251{
252	const uint8_t hlen =
253		(sizeof(struct guehdr) / sizeof(uint32_t)) + hop_count;
254	*encap = (encap_headers_t){
255		.eth = { .h_proto = htons(ETH_P_IP) },
256		.ip = {
257			.ihl = 5,
258			.version = 4,
259			.ttl = IPDEFTTL,
260			.protocol = IPPROTO_UDP,
261			.daddr = htonl(ENCAP_IP)
262		},
263		.udp = {
264			.dest = htons(ENCAP_PORT),
265		},
266		.gue = {
267			.hlen = hlen,
268			.proto_ctype = proto
269		},
270		.unigue = {
271			.hop_count = hop_count
272		},
273	};
274}
275
276static size_t build_input(const struct test_cfg *test, void *const buf,
277			  const struct tuple *tuple)
278{
279	in_port_t sport = tuple->src.port;
280	encap_headers_t encap;
281	struct iphdr ip;
282	struct ipv6hdr ipv6;
283	struct tcphdr tcp;
284	struct udphdr udp;
285	struct in_addr next_hop;
286	uint8_t *p = buf;
287	int proto;
288
289	proto = IPPROTO_IPIP;
290	if (tuple->family == AF_INET6)
291		proto = IPPROTO_IPV6;
292
293	encap_init(&encap, test->hops == ONE_HOP ? 1 : 0, proto);
294	p = mempcpy(p, &encap, sizeof(encap));
295
296	if (test->hops == ONE_HOP) {
297		next_hop = (struct in_addr){ .s_addr = htonl(0x7f000002) };
298		p = mempcpy(p, &next_hop, sizeof(next_hop));
299	}
300
301	proto = IPPROTO_TCP;
302	if (test->type == UDP)
303		proto = IPPROTO_UDP;
304
305	switch (tuple->family) {
306	case AF_INET:
307		ip = (struct iphdr){
308			.ihl = 5,
309			.version = 4,
310			.ttl = IPDEFTTL,
311			.protocol = proto,
312			.saddr = tuple->src.in_addr.s_addr,
313			.daddr = tuple->dst.in_addr.s_addr,
314		};
315		p = mempcpy(p, &ip, sizeof(ip));
316		break;
317	case AF_INET6:
318		ipv6 = (struct ipv6hdr){
319			.version = 6,
320			.hop_limit = IPDEFTTL,
321			.nexthdr = proto,
322			.saddr = tuple->src.in6_addr,
323			.daddr = tuple->dst.in6_addr,
324		};
325		p = mempcpy(p, &ipv6, sizeof(ipv6));
326		break;
327	default:
328		return 0;
329	}
330
331	if (test->conn == UNKNOWN_CONN)
332		sport--;
333
334	switch (test->type) {
335	case TCP:
336		tcp = (struct tcphdr){
337			.source = sport,
338			.dest = tuple->dst.port,
339		};
340		if (test->flags == SYN)
341			tcp.syn = true;
342		if (test->flags == ACK)
343			tcp.ack = true;
344		p = mempcpy(p, &tcp, sizeof(tcp));
345		break;
346	case UDP:
347		udp = (struct udphdr){
348			.source = sport,
349			.dest = tuple->dst.port,
350		};
351		p = mempcpy(p, &udp, sizeof(udp));
352		break;
353	default:
354		return 0;
355	}
356
357	return (void *)p - buf;
358}
359
/*
 * Close each of the @n descriptors in @fds that is greater than zero.
 * Zero entries (never-filled slots of zero-initialised arrays) and negative
 * entries (failed setups) are skipped.
 */
static void close_fds(int *fds, int n)
{
	for (; n > 0; n--, fds++) {
		if (*fds > 0)
			close(*fds);
	}
}
368
/*
 * Run every entry of tests[] against @prog once per address family.
 *
 * Setup: for IPv4 and IPv6, a UDP and a TCP client/server pair is created on
 * loopback via set_up_conn(). Each case then builds an encapsulated packet
 * with build_input(), runs it through bpf_prog_test_run_opts() and expects
 * TC_ACT_REDIRECT. ACCEPT cases must additionally come out shorter than they
 * went in (decapsulated), and FORWARD cases must not.
 */
static void test_cls_redirect_common(struct bpf_program *prog)
{
	LIBBPF_OPTS(bpf_test_run_opts, tattr);
	int families[] = { AF_INET, AF_INET6 };
	struct sockaddr_storage ss;
	struct sockaddr *addr;
	socklen_t slen;
	int i, j, err, prog_fd;
	/* Zero-initialised so close_fds() (which skips fd 0) ignores slots
	 * that were never filled if setup bails out early.
	 */
	int servers[__NR_KIND][ARRAY_SIZE(families)] = {};
	int conns[__NR_KIND][ARRAY_SIZE(families)] = {};
	/* Indexed [enum type][family index]; filled by set_up_conn(). */
	struct tuple tuples[__NR_KIND][ARRAY_SIZE(families)];

	addr = (struct sockaddr *)&ss;
	for (i = 0; i < ARRAY_SIZE(families); i++) {
		slen = prepare_addr(&ss, families[i]);
		if (CHECK_FAIL(!slen))
			goto cleanup;

		if (CHECK_FAIL(!set_up_conn(addr, slen, SOCK_DGRAM,
					    &servers[UDP][i], &conns[UDP][i],
					    &tuples[UDP][i])))
			goto cleanup;

		if (CHECK_FAIL(!set_up_conn(addr, slen, SOCK_STREAM,
					    &servers[TCP][i], &conns[TCP][i],
					    &tuples[TCP][i])))
			goto cleanup;
	}

	prog_fd = bpf_program__fd(prog);
	for (i = 0; i < ARRAY_SIZE(tests); i++) {
		struct test_cfg *test = &tests[i];

		for (j = 0; j < ARRAY_SIZE(families); j++) {
			struct tuple *tuple = &tuples[test->type][j];
			/* NOTE(review): build_input() takes no size; 256 bytes is
			 * assumed to cover the largest layout it emits.
			 */
			char input[256];
			char tmp[256];

			test_str(tmp, sizeof(tmp), test, tuple->family);
			if (!test__start_subtest(tmp))
				continue;

			/* NOTE(review): tmp is reused as the output buffer after
			 * being passed to test__start_subtest(); this assumes
			 * the harness copied the subtest name — confirm.
			 * data_size_out is reset every iteration because the
			 * test run overwrites it with the actual output size.
			 */
			tattr.data_out = tmp;
			tattr.data_size_out = sizeof(tmp);

			tattr.data_in = input;
			tattr.data_size_in = build_input(test, input, tuple);
			if (CHECK_FAIL(!tattr.data_size_in))
				continue;

			err = bpf_prog_test_run_opts(prog_fd, &tattr);
			if (CHECK_FAIL(err))
				continue;

			/* Every case is expected to be redirected; only the
			 * packet size distinguishes ACCEPT from FORWARD.
			 */
			if (tattr.retval != TC_ACT_REDIRECT) {
				PRINT_FAIL("expected TC_ACT_REDIRECT, got %d\n",
					   tattr.retval);
				continue;
			}

			switch (test->result) {
			case ACCEPT:
				if (CHECK_FAIL(!was_decapsulated(&tattr)))
					continue;
				break;
			case FORWARD:
				if (CHECK_FAIL(was_decapsulated(&tattr)))
					continue;
				break;
			default:
				PRINT_FAIL("unknown result %d\n", test->result);
				continue;
			}
		}
	}

cleanup:
	/* The 2-D fd arrays are walked as flat int arrays. */
	close_fds((int *)servers, sizeof(servers) / sizeof(servers[0][0]));
	close_fds((int *)conns, sizeof(conns) / sizeof(conns[0][0]));
}
449
450static void test_cls_redirect_dynptr(void)
451{
452	struct test_cls_redirect_dynptr *skel;
453	int err;
454
455	skel = test_cls_redirect_dynptr__open();
456	if (!ASSERT_OK_PTR(skel, "skel_open"))
457		return;
458
459	skel->rodata->ENCAPSULATION_IP = htonl(ENCAP_IP);
460	skel->rodata->ENCAPSULATION_PORT = htons(ENCAP_PORT);
461
462	err = test_cls_redirect_dynptr__load(skel);
463	if (!ASSERT_OK(err, "skel_load"))
464		goto cleanup;
465
466	test_cls_redirect_common(skel->progs.cls_redirect);
467
468cleanup:
469	test_cls_redirect_dynptr__destroy(skel);
470}
471
472static void test_cls_redirect_inlined(void)
473{
474	struct test_cls_redirect *skel;
475	int err;
476
477	skel = test_cls_redirect__open();
478	if (CHECK(!skel, "skel_open", "failed\n"))
479		return;
480
481	skel->rodata->ENCAPSULATION_IP = htonl(ENCAP_IP);
482	skel->rodata->ENCAPSULATION_PORT = htons(ENCAP_PORT);
483
484	err = test_cls_redirect__load(skel);
485	if (CHECK(err, "skel_load", "failed: %d\n", err))
486		goto cleanup;
487
488	test_cls_redirect_common(skel->progs.cls_redirect);
489
490cleanup:
491	test_cls_redirect__destroy(skel);
492}
493
494static void test_cls_redirect_subprogs(void)
495{
496	struct test_cls_redirect_subprogs *skel;
497	int err;
498
499	skel = test_cls_redirect_subprogs__open();
500	if (CHECK(!skel, "skel_open", "failed\n"))
501		return;
502
503	skel->rodata->ENCAPSULATION_IP = htonl(ENCAP_IP);
504	skel->rodata->ENCAPSULATION_PORT = htons(ENCAP_PORT);
505
506	err = test_cls_redirect_subprogs__load(skel);
507	if (CHECK(err, "skel_load", "failed: %d\n", err))
508		goto cleanup;
509
510	test_cls_redirect_common(skel->progs.cls_redirect);
511
512cleanup:
513	test_cls_redirect_subprogs__destroy(skel);
514}
515
516void test_cls_redirect(void)
517{
518	if (test__start_subtest("cls_redirect_inlined"))
519		test_cls_redirect_inlined();
520	if (test__start_subtest("cls_redirect_subprogs"))
521		test_cls_redirect_subprogs();
522	if (test__start_subtest("cls_redirect_dynptr"))
523		test_cls_redirect_dynptr();
524}
525