1// SPDX-License-Identifier: GPL-2.0
2/* Copyright (c) 2023 Isovalent */
3#include <stdbool.h>
4#include <linux/bpf.h>
5#include <linux/if_ether.h>
6#include <linux/in.h>
7#include <linux/ip.h>
8#include <linux/ipv6.h>
9#include <linux/tcp.h>
10#include <linux/udp.h>
11#include <bpf/bpf_endian.h>
12#include <bpf/bpf_helpers.h>
13#include <linux/pkt_cls.h>
14
/* License string, emitted into the "license" ELF section. */
char LICENSE[] SEC("license") = "GPL";

/*
 * Socket cookie of ctx->sk as seen by reuse_accept(); reset to 0 by
 * reuse_drop().  Presumably read back by userspace — confirm in the caller.
 */
__u64 sk_cookie_seen;
/* Incremented every time either sk_reuseport program below runs. */
__u64 reuseport_executed;
/*
 * Copy of the L4 header of the last packet the tc program decided to
 * steer (written in maybe_assign_tcp()/maybe_assign_udp()).  reuse_accept()
 * compares the packet it is handed against this copy byte-for-byte.
 */
union {
	struct tcphdr tcp;
	struct udphdr udp;
} headers;

/*
 * Destination port to steer, in host byte order (it is compared against
 * th->dest / uh->dest via bpf_htons()).  const volatile, so presumably set
 * by the userspace loader before the object is loaded — TODO confirm.
 */
const volatile __u16 dest_port;

/*
 * Single-slot sockmap.  assign_sk() looks up key 0 and assigns that socket
 * to matching packets; populating the slot is assumed to happen from
 * userspace (not visible here).
 */
struct {
	__uint(type, BPF_MAP_TYPE_SOCKMAP);
	__uint(max_entries, 1);
	__type(key, __u32);
	__type(value, __u64);
} sk_map SEC(".maps");
32
33SEC("sk_reuseport")
34int reuse_accept(struct sk_reuseport_md *ctx)
35{
36	reuseport_executed++;
37
38	if (ctx->ip_protocol == IPPROTO_TCP) {
39		if (ctx->data + sizeof(headers.tcp) > ctx->data_end)
40			return SK_DROP;
41
42		if (__builtin_memcmp(&headers.tcp, ctx->data, sizeof(headers.tcp)) != 0)
43			return SK_DROP;
44	} else if (ctx->ip_protocol == IPPROTO_UDP) {
45		if (ctx->data + sizeof(headers.udp) > ctx->data_end)
46			return SK_DROP;
47
48		if (__builtin_memcmp(&headers.udp, ctx->data, sizeof(headers.udp)) != 0)
49			return SK_DROP;
50	} else {
51		return SK_DROP;
52	}
53
54	sk_cookie_seen = bpf_get_socket_cookie(ctx->sk);
55	return SK_PASS;
56}
57
58SEC("sk_reuseport")
59int reuse_drop(struct sk_reuseport_md *ctx)
60{
61	reuseport_executed++;
62	sk_cookie_seen = 0;
63	return SK_DROP;
64}
65
66static int
67assign_sk(struct __sk_buff *skb)
68{
69	int zero = 0, ret = 0;
70	struct bpf_sock *sk;
71
72	sk = bpf_map_lookup_elem(&sk_map, &zero);
73	if (!sk)
74		return TC_ACT_SHOT;
75	ret = bpf_sk_assign(skb, sk, 0);
76	bpf_sk_release(sk);
77	return ret ? TC_ACT_SHOT : TC_ACT_OK;
78}
79
80static bool
81maybe_assign_tcp(struct __sk_buff *skb, struct tcphdr *th)
82{
83	if (th + 1 > (void *)(long)(skb->data_end))
84		return TC_ACT_SHOT;
85
86	if (!th->syn || th->ack || th->dest != bpf_htons(dest_port))
87		return TC_ACT_OK;
88
89	__builtin_memcpy(&headers.tcp, th, sizeof(headers.tcp));
90	return assign_sk(skb);
91}
92
93static bool
94maybe_assign_udp(struct __sk_buff *skb, struct udphdr *uh)
95{
96	if (uh + 1 > (void *)(long)(skb->data_end))
97		return TC_ACT_SHOT;
98
99	if (uh->dest != bpf_htons(dest_port))
100		return TC_ACT_OK;
101
102	__builtin_memcpy(&headers.udp, uh, sizeof(headers.udp));
103	return assign_sk(skb);
104}
105
106SEC("tc")
107int tc_main(struct __sk_buff *skb)
108{
109	void *data_end = (void *)(long)skb->data_end;
110	void *data = (void *)(long)skb->data;
111	struct ethhdr *eth;
112
113	eth = (struct ethhdr *)(data);
114	if (eth + 1 > data_end)
115		return TC_ACT_SHOT;
116
117	if (eth->h_proto == bpf_htons(ETH_P_IP)) {
118		struct iphdr *iph = (struct iphdr *)(data + sizeof(*eth));
119
120		if (iph + 1 > data_end)
121			return TC_ACT_SHOT;
122
123		if (iph->protocol == IPPROTO_TCP)
124			return maybe_assign_tcp(skb, (struct tcphdr *)(iph + 1));
125		else if (iph->protocol == IPPROTO_UDP)
126			return maybe_assign_udp(skb, (struct udphdr *)(iph + 1));
127		else
128			return TC_ACT_SHOT;
129	} else {
130		struct ipv6hdr *ip6h = (struct ipv6hdr *)(data + sizeof(*eth));
131
132		if (ip6h + 1 > data_end)
133			return TC_ACT_SHOT;
134
135		if (ip6h->nexthdr == IPPROTO_TCP)
136			return maybe_assign_tcp(skb, (struct tcphdr *)(ip6h + 1));
137		else if (ip6h->nexthdr == IPPROTO_UDP)
138			return maybe_assign_udp(skb, (struct udphdr *)(ip6h + 1));
139		else
140			return TC_ACT_SHOT;
141	}
142}
143