1// SPDX-License-Identifier: GPL-2.0
2// Copyright (c) 2020 Cloudflare
3
4#include <errno.h>
5#include <stdbool.h>
6#include <linux/bpf.h>
7
8#include <bpf/bpf_helpers.h>
9
/*
 * Two-slot BPF_MAP_TYPE_SOCKMAP keyed by __u32.
 *
 * The verdict and reuseport programs in this file pass it to their
 * redirect/select helper when test_sockmap is true, always with key 0.
 */
struct {
	__uint(type, BPF_MAP_TYPE_SOCKMAP);
	__uint(max_entries, 2);
	__type(key, __u32);
	__type(value, __u64);
} sock_map SEC(".maps");
16
/*
 * Second BPF_MAP_TYPE_SOCKMAP with the same shape as sock_map (two slots,
 * __u32 key, __u64 value).
 *
 * No program visible in this file references it.
 * NOTE(review): presumably it is only manipulated from user-space - confirm.
 */
struct {
	__uint(type, BPF_MAP_TYPE_SOCKMAP);
	__uint(max_entries, 2);
	__type(key, __u32);
	__type(value, __u64);
} nop_map SEC(".maps");
23
/*
 * Two-entry BPF_MAP_TYPE_SOCKHASH keyed by __u32.
 *
 * The verdict and reuseport programs in this file pass it to their
 * redirect/select helper when test_sockmap is false, always with key 0.
 */
struct {
	__uint(type, BPF_MAP_TYPE_SOCKHASH);
	__uint(max_entries, 2);
	__type(key, __u32);
	__type(value, __u64);
} sock_hash SEC(".maps");
30
/*
 * Two-entry array of unsigned counters, indexed by verdict value.
 *
 * Each verdict program in this file looks up the slot whose index equals
 * the verdict it is about to return and increments it. A verdict with no
 * matching slot (lookup returns NULL) is not counted.
 */
struct {
	__uint(type, BPF_MAP_TYPE_ARRAY);
	__uint(max_entries, 2);
	__type(key, int);
	__type(value, unsigned int);
} verdict_map SEC(".maps");
37
/*
 * Single-entry array read by prog_stream_parser at key 0.
 *
 * A non-zero value stored there is returned by the parser in place of
 * skb->len; a zero value leaves the parser returning skb->len.
 */
struct {
	__uint(type, BPF_MAP_TYPE_ARRAY);
	__uint(max_entries, 1);
	__type(key, int);
	__type(value, int);
} parser_map SEC(".maps");
44
/*
 * Test knobs read by the programs below.
 *
 * test_sockmap: when true the programs operate on sock_map, otherwise on
 *               sock_hash.
 * test_ingress: when true prog_skb_verdict passes BPF_F_INGRESS to the
 *               redirect helper, otherwise it passes 0 as flags.
 */
bool test_sockmap = false; /* toggled by user-space */
bool test_ingress = false; /* toggled by user-space */
47
48SEC("sk_skb/stream_parser")
49int prog_stream_parser(struct __sk_buff *skb)
50{
51	int *value;
52	__u32 key = 0;
53
54	value = bpf_map_lookup_elem(&parser_map, &key);
55	if (value && *value)
56		return *value;
57
58	return skb->len;
59}
60
61SEC("sk_skb/stream_verdict")
62int prog_stream_verdict(struct __sk_buff *skb)
63{
64	unsigned int *count;
65	__u32 zero = 0;
66	int verdict;
67
68	if (test_sockmap)
69		verdict = bpf_sk_redirect_map(skb, &sock_map, zero, 0);
70	else
71		verdict = bpf_sk_redirect_hash(skb, &sock_hash, &zero, 0);
72
73	count = bpf_map_lookup_elem(&verdict_map, &verdict);
74	if (count)
75		(*count)++;
76
77	return verdict;
78}
79
80SEC("sk_skb")
81int prog_skb_verdict(struct __sk_buff *skb)
82{
83	unsigned int *count;
84	__u32 zero = 0;
85	int verdict;
86
87	if (test_sockmap)
88		verdict = bpf_sk_redirect_map(skb, &sock_map, zero,
89					      test_ingress ? BPF_F_INGRESS : 0);
90	else
91		verdict = bpf_sk_redirect_hash(skb, &sock_hash, &zero,
92					       test_ingress ? BPF_F_INGRESS : 0);
93
94	count = bpf_map_lookup_elem(&verdict_map, &verdict);
95	if (count)
96		(*count)++;
97
98	return verdict;
99}
100
101SEC("sk_msg")
102int prog_msg_verdict(struct sk_msg_md *msg)
103{
104	unsigned int *count;
105	__u32 zero = 0;
106	int verdict;
107
108	if (test_sockmap)
109		verdict = bpf_msg_redirect_map(msg, &sock_map, zero, 0);
110	else
111		verdict = bpf_msg_redirect_hash(msg, &sock_hash, &zero, 0);
112
113	count = bpf_map_lookup_elem(&verdict_map, &verdict);
114	if (count)
115		(*count)++;
116
117	return verdict;
118}
119
120SEC("sk_reuseport")
121int prog_reuseport(struct sk_reuseport_md *reuse)
122{
123	unsigned int *count;
124	int err, verdict;
125	__u32 zero = 0;
126
127	if (test_sockmap)
128		err = bpf_sk_select_reuseport(reuse, &sock_map, &zero, 0);
129	else
130		err = bpf_sk_select_reuseport(reuse, &sock_hash, &zero, 0);
131	verdict = err ? SK_DROP : SK_PASS;
132
133	count = bpf_map_lookup_elem(&verdict_map, &verdict);
134	if (count)
135		(*count)++;
136
137	return verdict;
138}
139
/* License string for this object, placed in the "license" section via SEC(). */
char _license[] SEC("license") = "GPL";
141