1// SPDX-License-Identifier: GPL-2.0
2/* Copyright (c) 2020 Facebook */
3
4#include <vmlinux.h>
5#include <bpf/bpf_tracing.h>
6#include <bpf/bpf_core_read.h>
7#include <bpf/bpf_helpers.h>
8
/*
 * Value stored per socket in sk_stg_map.
 *
 * NOTE(review): the field layout presumably has to match the definition
 * used by the userspace side that reads this map — confirm before changing.
 */
struct sk_stg {
	/* Low 32 bits of bpf_get_current_pid_tgid(), written by set_task_info(). */
	__u32 pid;
	/* Most recent TCP state other than BPF_TCP_CLOSE, written by
	 * trace_inet_sock_set_state().
	 */
	__u32 last_notclose_state;
	/* Current task's comm, copied by set_task_info(). */
	char comm[16];
};
14
/*
 * Socket-local storage holding one struct sk_stg per socket. Entries are
 * created on demand (BPF_SK_STORAGE_GET_F_CREATE) by the programs below.
 * BPF_F_NO_PREALLOC is required by the kernel for SK_STORAGE maps; the int
 * key is the socket fd when the map is accessed from userspace.
 */
struct {
	__uint(type, BPF_MAP_TYPE_SK_STORAGE);
	__uint(map_flags, BPF_F_NO_PREALLOC);
	__type(key, int);
	__type(value, struct sk_stg);
} sk_stg_map SEC(".maps");
21
/*
 * Socket-local storage used only to exercise bpf_sk_storage_delete() from a
 * tracing program: trace_inet_sock_set_state() deletes from it, and nothing
 * in this file ever creates entries in it.
 */
struct {
	__uint(type, BPF_MAP_TYPE_SK_STORAGE);
	__uint(map_flags, BPF_F_NO_PREALLOC);
	__type(key, int);
	__type(value, int);
} del_sk_stg_map SEC(".maps");
29
/*
 * Global copy of the current task's comm, overwritten on every
 * set_task_info() call. Presumably read back by the userspace test to
 * compare against sk_stg.comm — confirm against the loader.
 */
char task_comm[16] = "";
31
32SEC("tp_btf/inet_sock_set_state")
33int BPF_PROG(trace_inet_sock_set_state, struct sock *sk, int oldstate,
34	     int newstate)
35{
36	struct sk_stg *stg;
37
38	if (newstate == BPF_TCP_CLOSE)
39		return 0;
40
41	stg = bpf_sk_storage_get(&sk_stg_map, sk, 0,
42				 BPF_SK_STORAGE_GET_F_CREATE);
43	if (!stg)
44		return 0;
45
46	stg->last_notclose_state = newstate;
47
48	bpf_sk_storage_delete(&del_sk_stg_map, sk);
49
50	return 0;
51}
52
53static void set_task_info(struct sock *sk)
54{
55	struct task_struct *task;
56	struct sk_stg *stg;
57
58	stg = bpf_sk_storage_get(&sk_stg_map, sk, 0,
59				 BPF_SK_STORAGE_GET_F_CREATE);
60	if (!stg)
61		return;
62
63	stg->pid = bpf_get_current_pid_tgid();
64
65	task = (struct task_struct *)bpf_get_current_task();
66	bpf_core_read_str(&stg->comm, sizeof(stg->comm), &task->comm);
67	bpf_core_read_str(&task_comm, sizeof(task_comm), &task->comm);
68}
69
70SEC("fentry/inet_csk_listen_start")
71int BPF_PROG(trace_inet_csk_listen_start, struct sock *sk)
72{
73	set_task_info(sk);
74
75	return 0;
76}
77
78SEC("fentry/tcp_connect")
79int BPF_PROG(trace_tcp_connect, struct sock *sk)
80{
81	set_task_info(sk);
82
83	return 0;
84}
85
86SEC("fexit/inet_csk_accept")
87int BPF_PROG(inet_csk_accept, struct sock *sk, int flags, int *err, bool kern,
88	     struct sock *accepted_sk)
89{
90	set_task_info(accepted_sk);
91
92	return 0;
93}
94
95SEC("tp_btf/tcp_retransmit_synack")
96int BPF_PROG(tcp_retransmit_synack, struct sock* sk, struct request_sock* req)
97{
98	/* load only test */
99	bpf_sk_storage_get(&sk_stg_map, sk, 0, 0);
100	bpf_sk_storage_get(&sk_stg_map, req->sk, 0, 0);
101	return 0;
102}
103
104SEC("tp_btf/tcp_bad_csum")
105int BPF_PROG(tcp_bad_csum, struct sk_buff* skb)
106{
107	bpf_sk_storage_get(&sk_stg_map, skb->sk, 0, 0);
108	return 0;
109}
110
/* Program license; a GPL-compatible license lets the kernel allow the
 * GPL-only helpers used above (e.g. bpf_get_current_task()).
 */
char _license[] SEC("license") = "GPL";
112