1// SPDX-License-Identifier: GPL-2.0
2/* Copyright (c) 2021 Facebook */
3#include <stdbool.h>
4#include <stdint.h>
5#include <linux/stddef.h>
6#include <linux/if_ether.h>
7#include <linux/in.h>
8#include <linux/in6.h>
9#include <linux/ip.h>
10#include <linux/ipv6.h>
11#include <linux/tcp.h>
12#include <linux/udp.h>
13#include <linux/bpf.h>
14#include <linux/types.h>
15#include <bpf/bpf_endian.h>
16#include <bpf/bpf_helpers.h>
17
/*
 * Result codes of the header parsers parse_ipv6_gue() and parse_gue_v6().
 * edgewall() drops the packet on any non-zero code.
 */
enum pkt_parse_err {
	NO_ERR = 0,		/* headers parsed successfully */
	BAD_IP6_HDR = 1,	/* outer IPv6 or outer UDP header truncated */
	BAD_IP4GUE_HDR = 2,	/* not returned by any parser visible here */
	BAD_IP6GUE_HDR = 3,	/* encapsulated (inner) header truncated */
};
24
/* Bit flags accumulated in pkt_info.flags while parsing a packet. */
enum pkt_flag {
	TUNNEL			= 1 << 0,	/* outer UDP dport 6666 (GUE tunnel) */
	TCP_SYN			= 1 << 1,	/* ACK/RST/SYN/FIN bits are exactly SYN */
	QUIC_INITIAL_FLAG	= 1 << 2,	/* not set by any code visible here */
	TCP_ACK			= 1 << 3,	/* ACK/RST/SYN/FIN bits are exactly ACK */
	TCP_RST			= 1 << 4,	/* ACK/RST/SYN/FIN bits are exactly RST */
};
32
/*
 * Key of v4_lpm_val_map (a BPF_MAP_TYPE_LPM_TRIE). filter_ipv4_lpm() always
 * looks up with prefixlen = 32 and src set to the IPv4 address exactly as
 * read from the packet header.
 * NOTE(review): prefixlen presumably has to stay the first member to match
 * the kernel's LPM trie key layout -- confirm before reordering.
 */
struct v4_lpm_key {
	__u32 prefixlen;	/* prefix length, in bits */
	__u32 src;		/* IPv4 address being matched */
};
37
/*
 * Value of v4_lpm_val_map. Only @val is read by the lookup in this file
 * (filter_ipv4_lpm()); @key is stored alongside it but never read here.
 */
struct v4_lpm_val {
	struct v4_lpm_key key;
	__u8 val;		/* match result returned by filter_ipv4_lpm() */
};
42
/*
 * Exact-match set of IPv6 source addresses. filter_ipv6_addr() looks up the
 * packet's saddr and reads the stored bool back through a __u8 pointer.
 */
struct {
	__uint(type, BPF_MAP_TYPE_HASH);
	__uint(max_entries, 16);
	__type(key, struct in6_addr);
	__type(value, bool);
} v6_addr_map SEC(".maps");
49
/*
 * Exact-match set of IPv4 source addresses. Keys are the raw saddr value
 * from the IPv4 header; filter_ipv4_addr() applies no byte-order conversion.
 * The stored bool is read back through a __u8 pointer.
 */
struct {
	__uint(type, BPF_MAP_TYPE_HASH);
	__uint(max_entries, 16);
	__type(key, __u32);
	__type(value, bool);
} v4_addr_map SEC(".maps");
56
/*
 * Longest-prefix-match table of IPv4 prefixes, declared with explicit
 * key/value sizes instead of __type(). filter_ipv4_lpm() queries it with
 * /32 keys for both source and destination addresses.
 * BPF_F_NO_PREALLOC: the kernel's LPM trie documentation states this flag
 * must be set for this map type -- keep it.
 */
struct {
	__uint(type, BPF_MAP_TYPE_LPM_TRIE);
	__uint(max_entries, 16);
	__uint(key_size, sizeof(struct v4_lpm_key));
	__uint(value_size, sizeof(struct v4_lpm_val));
	__uint(map_flags, BPF_F_NO_PREALLOC);
} v4_lpm_val_map SEC(".maps");
64
/*
 * TCP destination-port table, indexed directly by the port number
 * (filter_tcp_port()). With max_entries = 16 only ports 0-15 can have an
 * entry; an array lookup with a larger index finds nothing, so the filter
 * yields 0 for those ports.
 */
struct {
	__uint(type, BPF_MAP_TYPE_ARRAY);
	__uint(max_entries, 16);
	__type(key, int);
	__type(value, __u8);
} tcp_port_map SEC(".maps");
71
/*
 * UDP port table, indexed directly by port number and consulted for both
 * the source and the destination port (filter_udp_port()). Values are
 * 16-bit, unlike tcp_port_map. As with tcp_port_map, only ports 0-15 fall
 * within max_entries.
 */
struct {
	__uint(type, BPF_MAP_TYPE_ARRAY);
	__uint(max_entries, 16);
	__type(key, int);
	__type(value, __u16);
} udp_port_map SEC(".maps");
78
/*
 * Selects which member of pkt_info.ip is valid. 0 is deliberately unused,
 * so a zero-initialised pkt_info matches neither type.
 */
enum ip_type {
	V4 = 1,
	V6 = 2,
};
80
/*
 * Results of the map lookups performed for one packet. edgewall() zero-
 * initialises it; each field is filled by the filter noted beside it. Of
 * these, only is_tcp and is_tcp_syn influence edgewall()'s verdict.
 */
struct fw_match_info {
	__u8 v4_src_ip_match;		/* v4_addr_map hit for IPv4 saddr */
	__u8 v6_src_ip_match;		/* v6_addr_map hit for IPv6 saddr */
	__u8 v4_src_prefix_match;	/* v4_lpm_val_map val for IPv4 saddr */
	__u8 v4_dst_prefix_match;	/* v4_lpm_val_map val for IPv4 daddr */
	__u8 tcp_dp_match;		/* tcp_port_map value for TCP dport */
	__u16 udp_sp_match;		/* udp_port_map value for UDP sport */
	__u16 udp_dp_match;		/* udp_port_map value for UDP dport */
	bool is_tcp;			/* transport header parsed as TCP */
	bool is_tcp_syn;		/* TCP_SYN set in pkt_info.flags */
};
92
/*
 * Parsing state for one packet. For GUE-tunnelled packets the ip, type and
 * proto fields describe the inner (encapsulated) IP header.
 */
struct pkt_info {
	enum ip_type type;		/* selects the valid member of @ip */
	union {
		struct iphdr *ipv4;	/* points into packet data */
		struct ipv6hdr *ipv6;	/* points into packet data */
	} ip;
	int sport;			/* host byte order (bpf_ntohs) */
	int dport;			/* host byte order (bpf_ntohs) */
	__u16 trans_hdr_offset;		/* bytes from packet start to L4 header */
	__u8 proto;			/* L4 protocol (nexthdr / protocol) */
	__u8 flags;			/* bitmask of enum pkt_flag */
};
105
106static __always_inline struct ethhdr *parse_ethhdr(void *data, void *data_end)
107{
108	struct ethhdr *eth = data;
109
110	if (eth + 1 > data_end)
111		return NULL;
112
113	return eth;
114}
115
116static __always_inline __u8 filter_ipv6_addr(const struct in6_addr *ipv6addr)
117{
118	__u8 *leaf;
119
120	leaf = bpf_map_lookup_elem(&v6_addr_map, ipv6addr);
121
122	return leaf ? *leaf : 0;
123}
124
125static __always_inline __u8 filter_ipv4_addr(const __u32 ipaddr)
126{
127	__u8 *leaf;
128
129	leaf = bpf_map_lookup_elem(&v4_addr_map, &ipaddr);
130
131	return leaf ? *leaf : 0;
132}
133
134static __always_inline __u8 filter_ipv4_lpm(const __u32 ipaddr)
135{
136	struct v4_lpm_key v4_key = {};
137	struct v4_lpm_val *lpm_val;
138
139	v4_key.src = ipaddr;
140	v4_key.prefixlen = 32;
141
142	lpm_val = bpf_map_lookup_elem(&v4_lpm_val_map, &v4_key);
143
144	return lpm_val ? lpm_val->val : 0;
145}
146
147
148static __always_inline void
149filter_src_dst_ip(struct pkt_info* info, struct fw_match_info* match_info)
150{
151	if (info->type == V6) {
152		match_info->v6_src_ip_match =
153			filter_ipv6_addr(&info->ip.ipv6->saddr);
154	} else if (info->type == V4) {
155		match_info->v4_src_ip_match =
156			filter_ipv4_addr(info->ip.ipv4->saddr);
157		match_info->v4_src_prefix_match =
158			filter_ipv4_lpm(info->ip.ipv4->saddr);
159		match_info->v4_dst_prefix_match =
160			filter_ipv4_lpm(info->ip.ipv4->daddr);
161	}
162}
163
/*
 * get_transport_hdr - return a pointer @offset bytes into the packet.
 * @offset:   byte offset of the transport header from @data
 * @data:     start of packet data
 * @data_end: one past the last byte of packet data
 *
 * Returns NULL when @offset is larger than 255 or @data + @offset lies past
 * @data_end. Only the start of the header is checked here; the full header
 * length is checked later by parse_tcp()/parse_udp().
 */
static __always_inline void *
get_transport_hdr(__u16 offset, void *data, void *data_end)
{
	/*
	 * The explicit upper bound gives @offset a small known range before it
	 * is added to the packet pointer; presumably this is what lets the
	 * verifier accept the arithmetic -- keep the check in this form.
	 */
	if (offset > 255 || data + offset > data_end)
		return NULL;

	return data + offset;
}
172
173static __always_inline bool tcphdr_only_contains_flag(struct tcphdr *tcp,
174						      __u32 FLAG)
175{
176	return (tcp_flag_word(tcp) &
177		(TCP_FLAG_ACK | TCP_FLAG_RST | TCP_FLAG_SYN | TCP_FLAG_FIN)) == FLAG;
178}
179
180static __always_inline void set_tcp_flags(struct pkt_info *info,
181					  struct tcphdr *tcp) {
182	if (tcphdr_only_contains_flag(tcp, TCP_FLAG_SYN))
183		info->flags |= TCP_SYN;
184	else if (tcphdr_only_contains_flag(tcp, TCP_FLAG_ACK))
185		info->flags |= TCP_ACK;
186	else if (tcphdr_only_contains_flag(tcp, TCP_FLAG_RST))
187		info->flags |= TCP_RST;
188}
189
190static __always_inline bool
191parse_tcp(struct pkt_info *info, void *transport_hdr, void *data_end)
192{
193	struct tcphdr *tcp = transport_hdr;
194
195	if (tcp + 1 > data_end)
196		return false;
197
198	info->sport = bpf_ntohs(tcp->source);
199	info->dport = bpf_ntohs(tcp->dest);
200	set_tcp_flags(info, tcp);
201
202	return true;
203}
204
205static __always_inline bool
206parse_udp(struct pkt_info *info, void *transport_hdr, void *data_end)
207{
208	struct udphdr *udp = transport_hdr;
209
210	if (udp + 1 > data_end)
211		return false;
212
213	info->sport = bpf_ntohs(udp->source);
214	info->dport = bpf_ntohs(udp->dest);
215
216	return true;
217}
218
219static __always_inline __u8 filter_tcp_port(int port)
220{
221	__u8 *leaf = bpf_map_lookup_elem(&tcp_port_map, &port);
222
223	return leaf ? *leaf : 0;
224}
225
226static __always_inline __u16 filter_udp_port(int port)
227{
228	__u16 *leaf = bpf_map_lookup_elem(&udp_port_map, &port);
229
230	return leaf ? *leaf : 0;
231}
232
233static __always_inline bool
234filter_transport_hdr(void *transport_hdr, void *data_end,
235		     struct pkt_info *info, struct fw_match_info *match_info)
236{
237	if (info->proto == IPPROTO_TCP) {
238		if (!parse_tcp(info, transport_hdr, data_end))
239			return false;
240
241		match_info->is_tcp = true;
242		match_info->is_tcp_syn = (info->flags & TCP_SYN) > 0;
243
244		match_info->tcp_dp_match = filter_tcp_port(info->dport);
245	} else if (info->proto == IPPROTO_UDP) {
246		if (!parse_udp(info, transport_hdr, data_end))
247			return false;
248
249		match_info->udp_dp_match = filter_udp_port(info->dport);
250		match_info->udp_sp_match = filter_udp_port(info->sport);
251	}
252
253	return true;
254}
255
/*
 * parse_gue_v6 - detect a UDP-encapsulated (GUE) packet and parse its inner
 * IP header.
 * @info:     packet state, updated in place when a tunnel is found
 * @ip6h:     outer IPv6 header; parse_ipv6_gue() bounds-checks it and calls
 *            this only when its next header is UDP
 * @data_end: one past the last byte of packet data
 *
 * Packets whose outer UDP destination port is not 6666 are left unchanged.
 * For port-6666 packets, TUNNEL is set in info->flags and info->type,
 * info->proto and info->ip are re-pointed at the inner IPv4 or IPv6 header.
 * info->trans_hdr_offset is advanced past the outer UDP header and the
 * inner IP header.
 *
 * Returns NO_ERR on success, BAD_IP6_HDR if the outer UDP header is
 * truncated, or BAD_IP6GUE_HDR if the encapsulated data is truncated.
 */
static __always_inline __u8
parse_gue_v6(struct pkt_info *info, struct ipv6hdr *ip6h, void *data_end)
{
	struct udphdr *udp = (struct udphdr *)(ip6h + 1);
	void *encap_data = udp + 1;

	if (udp + 1 > data_end)
		return BAD_IP6_HDR;

	/* Only UDP destination port 6666 is treated as a tunnel. */
	if (udp->dest != bpf_htons(6666))
		return NO_ERR;

	info->flags |= TUNNEL;

	/* At least one byte is needed to read the inner IP version. */
	if (encap_data + 1 > data_end)
		return BAD_IP6GUE_HDR;

	/*
	 * Inner IP version detection: the first byte is 0x6_ for IPv6 and
	 * 0x4_ for IPv4. Masking with 0x30 gives 0x20 for 0x6_ and 0 for 0x4_,
	 * so a non-zero result is taken as IPv6 and anything else as IPv4.
	 */
	if (*(__u8 *)encap_data & 0x30) {
		struct ipv6hdr *inner_ip6h = encap_data;

		if (inner_ip6h + 1 > data_end)
			return BAD_IP6GUE_HDR;

		info->type = V6;
		info->proto = inner_ip6h->nexthdr;
		info->ip.ipv6 = inner_ip6h;
		info->trans_hdr_offset += sizeof(struct ipv6hdr) + sizeof(struct udphdr);
	} else {
		struct iphdr *inner_ip4h = encap_data;

		if (inner_ip4h + 1 > data_end)
			return BAD_IP6GUE_HDR;

		info->type = V4;
		info->proto = inner_ip4h->protocol;
		info->ip.ipv4 = inner_ip4h;
		/*
		 * NOTE(review): assumes a 20-byte inner IPv4 header; ihl is not
		 * consulted, so IPv4 options would misplace the L4 header.
		 */
		info->trans_hdr_offset += sizeof(struct iphdr) + sizeof(struct udphdr);
	}

	return NO_ERR;
}
297
298static __always_inline __u8 parse_ipv6_gue(struct pkt_info *info,
299					   void *data, void *data_end)
300{
301	struct ipv6hdr *ip6h = data + sizeof(struct ethhdr);
302
303	if (ip6h + 1 > data_end)
304		return BAD_IP6_HDR;
305
306	info->proto = ip6h->nexthdr;
307	info->ip.ipv6 = ip6h;
308	info->type = V6;
309	info->trans_hdr_offset = sizeof(struct ethhdr) + sizeof(struct ipv6hdr);
310
311	if (info->proto == IPPROTO_UDP)
312		return parse_gue_v6(info, ip6h, data_end);
313
314	return NO_ERR;
315}
316
/*
 * edgewall - XDP entry point.
 *
 * Verdicts, checked in this order:
 *  - XDP_DROP: frame shorter than an Ethernet header, EtherType not IPv6,
 *    or IPv6/GUE header parsing failed;
 *  - XDP_PASS: info.proto is IPPROTO_ICMPV6 (after GUE decapsulation this
 *    is the inner header's protocol);
 *  - XDP_DROP: protocol is neither TCP nor UDP;
 *  - XDP_DROP: transport header offset out of range or header truncated;
 *  - XDP_PASS: TCP segment whose ACK/RST/SYN/FIN bits are not SYN alone;
 *  - XDP_DROP: everything else (all UDP, and SYN-only TCP).
 *
 * The address and port lookups populate match_info, but only is_tcp and
 * is_tcp_syn affect the verdict here.
 */
SEC("xdp")
int edgewall(struct xdp_md *ctx)
{
	void *data_end = (void *)(long)(ctx->data_end);
	void *data = (void *)(long)(ctx->data);
	struct fw_match_info match_info = {};
	struct pkt_info info = {};
	void *transport_hdr;
	struct ethhdr *eth;
	bool filter_res;
	__u32 proto;

	eth = parse_ethhdr(data, data_end);
	if (!eth)
		return XDP_DROP;

	/* h_proto is compared in network byte order. */
	proto = eth->h_proto;
	if (proto != bpf_htons(ETH_P_IPV6))
		return XDP_DROP;

	if (parse_ipv6_gue(&info, data, data_end))
		return XDP_DROP;

	if (info.proto == IPPROTO_ICMPV6)
		return XDP_PASS;

	if (info.proto != IPPROTO_TCP && info.proto != IPPROTO_UDP)
		return XDP_DROP;

	filter_src_dst_ip(&info, &match_info);

	transport_hdr = get_transport_hdr(info.trans_hdr_offset, data,
					  data_end);
	if (!transport_hdr)
		return XDP_DROP;

	filter_res = filter_transport_hdr(transport_hdr, data_end,
					  &info, &match_info);
	if (!filter_res)
		return XDP_DROP;

	if (match_info.is_tcp && !match_info.is_tcp_syn)
		return XDP_PASS;

	return XDP_DROP;
}
363
/* Program license string, emitted into the "license" section. */
char LICENSE[] SEC("license") = "GPL";
365