1// SPDX-License-Identifier: LGPL-2.1 OR BSD-2-Clause
2/* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. */
3
4#include <stdnoreturn.h>
5#include <stdlib.h>
6#include <stdio.h>
7#include <string.h>
8#include <errno.h>
9#include <unistd.h>
10#include <getopt.h>
11#include <signal.h>
12#include <sys/types.h>
13#include <bpf/bpf.h>
14#include <bpf/libbpf.h>
15#include <net/if.h>
16#include <linux/if_link.h>
17#include <linux/limits.h>
18
19static unsigned int ifindex;
20static __u32 attached_prog_id;
21static bool attached_tc;
22
23static void noreturn cleanup(int sig)
24{
25	LIBBPF_OPTS(bpf_xdp_attach_opts, opts);
26	int prog_fd;
27	int err;
28
29	if (attached_prog_id == 0)
30		exit(0);
31
32	if (attached_tc) {
33		LIBBPF_OPTS(bpf_tc_hook, hook,
34			    .ifindex = ifindex,
35			    .attach_point = BPF_TC_INGRESS);
36
37		err = bpf_tc_hook_destroy(&hook);
38		if (err < 0) {
39			fprintf(stderr, "Error: bpf_tc_hook_destroy: %s\n", strerror(-err));
40			fprintf(stderr, "Failed to destroy the TC hook\n");
41			exit(1);
42		}
43		exit(0);
44	}
45
46	prog_fd = bpf_prog_get_fd_by_id(attached_prog_id);
47	if (prog_fd < 0) {
48		fprintf(stderr, "Error: bpf_prog_get_fd_by_id: %s\n", strerror(-prog_fd));
49		err = bpf_xdp_attach(ifindex, -1, 0, NULL);
50		if (err < 0) {
51			fprintf(stderr, "Error: bpf_set_link_xdp_fd: %s\n", strerror(-err));
52			fprintf(stderr, "Failed to detach XDP program\n");
53			exit(1);
54		}
55	} else {
56		opts.old_prog_fd = prog_fd;
57		err = bpf_xdp_attach(ifindex, -1, XDP_FLAGS_REPLACE, &opts);
58		close(prog_fd);
59		if (err < 0) {
60			fprintf(stderr, "Error: bpf_set_link_xdp_fd_opts: %s\n", strerror(-err));
61			/* Not an error if already replaced by someone else. */
62			if (err != -EEXIST) {
63				fprintf(stderr, "Failed to detach XDP program\n");
64				exit(1);
65			}
66		}
67	}
68	exit(0);
69}
70
/*
 * Print the command-line synopsis to stderr and terminate with status 1.
 * @progname: program name shown in the synopsis (normally argv[0]).
 */
static noreturn void usage(const char *progname)
{
	fprintf(stderr,
		"Usage: %s [--iface <iface>|--prog <prog_id>] "
		"[--mss4 <mss ipv4> --mss6 <mss ipv6> --wscale <wscale> --ttl <ttl>] "
		"[--ports <port1>,<port2>,...] [--single] [--tc]\n",
		progname);
	exit(EXIT_FAILURE);
}
77
/*
 * Parse @arg as a plain base-10 unsigned integer no greater than @limit.
 * An empty string, leading whitespace or sign, trailing characters, overflow
 * or a value above @limit prints usage (naming @progname) and exits.
 */
static unsigned long parse_arg_ul(const char *progname, const char *arg, unsigned long limit)
{
	unsigned long res;
	char *endptr;

	/*
	 * strtoul() skips leading whitespace and accepts a sign, negating "-N"
	 * into ULONG_MAX - N + 1. Where unsigned long is 32 bits that value can
	 * land back inside the limit (e.g. "-5" vs UINT32_MAX), and "-0" is
	 * accepted everywhere, so insist the argument starts with a digit.
	 * This also rejects the empty string.
	 */
	if (arg[0] < '0' || arg[0] > '9')
		usage(progname);

	errno = 0;
	res = strtoul(arg, &endptr, 10);
	if (errno != 0 || *endptr != '\0' || res > limit)
		usage(progname);

	return res;
}
90
91static void parse_options(int argc, char *argv[], unsigned int *ifindex, __u32 *prog_id,
92			  __u64 *tcpipopts, char **ports, bool *single, bool *tc)
93{
94	static struct option long_options[] = {
95		{ "help", no_argument, NULL, 'h' },
96		{ "iface", required_argument, NULL, 'i' },
97		{ "prog", required_argument, NULL, 'x' },
98		{ "mss4", required_argument, NULL, 4 },
99		{ "mss6", required_argument, NULL, 6 },
100		{ "wscale", required_argument, NULL, 'w' },
101		{ "ttl", required_argument, NULL, 't' },
102		{ "ports", required_argument, NULL, 'p' },
103		{ "single", no_argument, NULL, 's' },
104		{ "tc", no_argument, NULL, 'c' },
105		{ NULL, 0, NULL, 0 },
106	};
107	unsigned long mss4, wscale, ttl;
108	unsigned long long mss6;
109	unsigned int tcpipopts_mask = 0;
110
111	if (argc < 2)
112		usage(argv[0]);
113
114	*ifindex = 0;
115	*prog_id = 0;
116	*tcpipopts = 0;
117	*ports = NULL;
118	*single = false;
119	*tc = false;
120
121	while (true) {
122		int opt;
123
124		opt = getopt_long(argc, argv, "", long_options, NULL);
125		if (opt == -1)
126			break;
127
128		switch (opt) {
129		case 'h':
130			usage(argv[0]);
131			break;
132		case 'i':
133			*ifindex = if_nametoindex(optarg);
134			if (*ifindex == 0)
135				usage(argv[0]);
136			break;
137		case 'x':
138			*prog_id = parse_arg_ul(argv[0], optarg, UINT32_MAX);
139			if (*prog_id == 0)
140				usage(argv[0]);
141			break;
142		case 4:
143			mss4 = parse_arg_ul(argv[0], optarg, UINT16_MAX);
144			tcpipopts_mask |= 1 << 0;
145			break;
146		case 6:
147			mss6 = parse_arg_ul(argv[0], optarg, UINT16_MAX);
148			tcpipopts_mask |= 1 << 1;
149			break;
150		case 'w':
151			wscale = parse_arg_ul(argv[0], optarg, 14);
152			tcpipopts_mask |= 1 << 2;
153			break;
154		case 't':
155			ttl = parse_arg_ul(argv[0], optarg, UINT8_MAX);
156			tcpipopts_mask |= 1 << 3;
157			break;
158		case 'p':
159			*ports = optarg;
160			break;
161		case 's':
162			*single = true;
163			break;
164		case 'c':
165			*tc = true;
166			break;
167		default:
168			usage(argv[0]);
169		}
170	}
171	if (optind < argc)
172		usage(argv[0]);
173
174	if (tcpipopts_mask == 0xf) {
175		if (mss4 == 0 || mss6 == 0 || wscale == 0 || ttl == 0)
176			usage(argv[0]);
177		*tcpipopts = (mss6 << 32) | (ttl << 24) | (wscale << 16) | mss4;
178	} else if (tcpipopts_mask != 0) {
179		usage(argv[0]);
180	}
181
182	if (*ifindex != 0 && *prog_id != 0)
183		usage(argv[0]);
184	if (*ifindex == 0 && *prog_id == 0)
185		usage(argv[0]);
186}
187
188static int syncookie_attach(const char *argv0, unsigned int ifindex, bool tc)
189{
190	struct bpf_prog_info info = {};
191	__u32 info_len = sizeof(info);
192	char xdp_filename[PATH_MAX];
193	struct bpf_program *prog;
194	struct bpf_object *obj;
195	int prog_fd;
196	int err;
197
198	snprintf(xdp_filename, sizeof(xdp_filename), "%s_kern.bpf.o", argv0);
199	obj = bpf_object__open_file(xdp_filename, NULL);
200	err = libbpf_get_error(obj);
201	if (err < 0) {
202		fprintf(stderr, "Error: bpf_object__open_file: %s\n", strerror(-err));
203		return err;
204	}
205
206	err = bpf_object__load(obj);
207	if (err < 0) {
208		fprintf(stderr, "Error: bpf_object__open_file: %s\n", strerror(-err));
209		return err;
210	}
211
212	prog = bpf_object__find_program_by_name(obj, tc ? "syncookie_tc" : "syncookie_xdp");
213	if (!prog) {
214		fprintf(stderr, "Error: bpf_object__find_program_by_name: program was not found\n");
215		return -ENOENT;
216	}
217
218	prog_fd = bpf_program__fd(prog);
219
220	err = bpf_prog_get_info_by_fd(prog_fd, &info, &info_len);
221	if (err < 0) {
222		fprintf(stderr, "Error: bpf_prog_get_info_by_fd: %s\n",
223			strerror(-err));
224		goto out;
225	}
226	attached_tc = tc;
227	attached_prog_id = info.id;
228	signal(SIGINT, cleanup);
229	signal(SIGTERM, cleanup);
230	if (tc) {
231		LIBBPF_OPTS(bpf_tc_hook, hook,
232			    .ifindex = ifindex,
233			    .attach_point = BPF_TC_INGRESS);
234		LIBBPF_OPTS(bpf_tc_opts, opts,
235			    .handle = 1,
236			    .priority = 1,
237			    .prog_fd = prog_fd);
238
239		err = bpf_tc_hook_create(&hook);
240		if (err < 0) {
241			fprintf(stderr, "Error: bpf_tc_hook_create: %s\n",
242				strerror(-err));
243			goto fail;
244		}
245		err = bpf_tc_attach(&hook, &opts);
246		if (err < 0) {
247			fprintf(stderr, "Error: bpf_tc_attach: %s\n",
248				strerror(-err));
249			goto fail;
250		}
251
252	} else {
253		err = bpf_xdp_attach(ifindex, prog_fd,
254				     XDP_FLAGS_UPDATE_IF_NOEXIST, NULL);
255		if (err < 0) {
256			fprintf(stderr, "Error: bpf_set_link_xdp_fd: %s\n",
257				strerror(-err));
258			goto fail;
259		}
260	}
261	err = 0;
262out:
263	bpf_object__close(obj);
264	return err;
265fail:
266	signal(SIGINT, SIG_DFL);
267	signal(SIGTERM, SIG_DFL);
268	attached_prog_id = 0;
269	goto out;
270}
271
272static int syncookie_open_bpf_maps(__u32 prog_id, int *values_map_fd, int *ports_map_fd)
273{
274	struct bpf_prog_info prog_info;
275	__u32 map_ids[8];
276	__u32 info_len;
277	int prog_fd;
278	int err;
279	int i;
280
281	*values_map_fd = -1;
282	*ports_map_fd = -1;
283
284	prog_fd = bpf_prog_get_fd_by_id(prog_id);
285	if (prog_fd < 0) {
286		fprintf(stderr, "Error: bpf_prog_get_fd_by_id: %s\n", strerror(-prog_fd));
287		return prog_fd;
288	}
289
290	prog_info = (struct bpf_prog_info) {
291		.nr_map_ids = 8,
292		.map_ids = (__u64)(unsigned long)map_ids,
293	};
294	info_len = sizeof(prog_info);
295
296	err = bpf_prog_get_info_by_fd(prog_fd, &prog_info, &info_len);
297	if (err != 0) {
298		fprintf(stderr, "Error: bpf_prog_get_info_by_fd: %s\n",
299			strerror(-err));
300		goto out;
301	}
302
303	if (prog_info.nr_map_ids < 2) {
304		fprintf(stderr, "Error: Found %u BPF maps, expected at least 2\n",
305			prog_info.nr_map_ids);
306		err = -ENOENT;
307		goto out;
308	}
309
310	for (i = 0; i < prog_info.nr_map_ids; i++) {
311		struct bpf_map_info map_info = {};
312		int map_fd;
313
314		err = bpf_map_get_fd_by_id(map_ids[i]);
315		if (err < 0) {
316			fprintf(stderr, "Error: bpf_map_get_fd_by_id: %s\n", strerror(-err));
317			goto err_close_map_fds;
318		}
319		map_fd = err;
320
321		info_len = sizeof(map_info);
322		err = bpf_map_get_info_by_fd(map_fd, &map_info, &info_len);
323		if (err != 0) {
324			fprintf(stderr, "Error: bpf_map_get_info_by_fd: %s\n",
325				strerror(-err));
326			close(map_fd);
327			goto err_close_map_fds;
328		}
329		if (strcmp(map_info.name, "values") == 0) {
330			*values_map_fd = map_fd;
331			continue;
332		}
333		if (strcmp(map_info.name, "allowed_ports") == 0) {
334			*ports_map_fd = map_fd;
335			continue;
336		}
337		close(map_fd);
338	}
339
340	if (*values_map_fd != -1 && *ports_map_fd != -1) {
341		err = 0;
342		goto out;
343	}
344
345	err = -ENOENT;
346
347err_close_map_fds:
348	if (*values_map_fd != -1)
349		close(*values_map_fd);
350	if (*ports_map_fd != -1)
351		close(*ports_map_fd);
352	*values_map_fd = -1;
353	*ports_map_fd = -1;
354
355out:
356	close(prog_fd);
357	return err;
358}
359
360int main(int argc, char *argv[])
361{
362	int values_map_fd, ports_map_fd;
363	__u64 tcpipopts;
364	bool firstiter;
365	__u64 prevcnt;
366	__u32 prog_id;
367	char *ports;
368	bool single;
369	int err = 0;
370	bool tc;
371
372	parse_options(argc, argv, &ifindex, &prog_id, &tcpipopts, &ports,
373		      &single, &tc);
374
375	if (prog_id == 0) {
376		if (!tc) {
377			err = bpf_xdp_query_id(ifindex, 0, &prog_id);
378			if (err < 0) {
379				fprintf(stderr, "Error: bpf_get_link_xdp_id: %s\n",
380					strerror(-err));
381				goto out;
382			}
383		}
384		if (prog_id == 0) {
385			err = syncookie_attach(argv[0], ifindex, tc);
386			if (err < 0)
387				goto out;
388			prog_id = attached_prog_id;
389		}
390	}
391
392	err = syncookie_open_bpf_maps(prog_id, &values_map_fd, &ports_map_fd);
393	if (err < 0)
394		goto out;
395
396	if (ports) {
397		__u16 port_last = 0;
398		__u32 port_idx = 0;
399		char *p = ports;
400
401		fprintf(stderr, "Replacing allowed ports\n");
402
403		while (p && *p != '\0') {
404			char *token = strsep(&p, ",");
405			__u16 port;
406
407			port = parse_arg_ul(argv[0], token, UINT16_MAX);
408			err = bpf_map_update_elem(ports_map_fd, &port_idx, &port, BPF_ANY);
409			if (err != 0) {
410				fprintf(stderr, "Error: bpf_map_update_elem: %s\n", strerror(-err));
411				fprintf(stderr, "Failed to add port %u (index %u)\n",
412					port, port_idx);
413				goto out_close_maps;
414			}
415			fprintf(stderr, "Added port %u\n", port);
416			port_idx++;
417		}
418		err = bpf_map_update_elem(ports_map_fd, &port_idx, &port_last, BPF_ANY);
419		if (err != 0) {
420			fprintf(stderr, "Error: bpf_map_update_elem: %s\n", strerror(-err));
421			fprintf(stderr, "Failed to add the terminator value 0 (index %u)\n",
422				port_idx);
423			goto out_close_maps;
424		}
425	}
426
427	if (tcpipopts) {
428		__u32 key = 0;
429
430		fprintf(stderr, "Replacing TCP/IP options\n");
431
432		err = bpf_map_update_elem(values_map_fd, &key, &tcpipopts, BPF_ANY);
433		if (err != 0) {
434			fprintf(stderr, "Error: bpf_map_update_elem: %s\n", strerror(-err));
435			goto out_close_maps;
436		}
437	}
438
439	if ((ports || tcpipopts) && attached_prog_id == 0 && !single)
440		goto out_close_maps;
441
442	prevcnt = 0;
443	firstiter = true;
444	while (true) {
445		__u32 key = 1;
446		__u64 value;
447
448		err = bpf_map_lookup_elem(values_map_fd, &key, &value);
449		if (err != 0) {
450			fprintf(stderr, "Error: bpf_map_lookup_elem: %s\n", strerror(-err));
451			goto out_close_maps;
452		}
453		if (firstiter) {
454			prevcnt = value;
455			firstiter = false;
456		}
457		if (single) {
458			printf("Total SYNACKs generated: %llu\n", value);
459			break;
460		}
461		printf("SYNACKs generated: %llu (total %llu)\n", value - prevcnt, value);
462		prevcnt = value;
463		sleep(1);
464	}
465
466out_close_maps:
467	close(values_map_fd);
468	close(ports_map_fd);
469out:
470	return err == 0 ? 0 : 1;
471}
472