1// SPDX-License-Identifier: GPL-2.0-only
2
3#include <sys/types.h>
4#include <sys/epoll.h>
5#include <sys/socket.h>
6#include <linux/netlink.h>
7#include <linux/connector.h>
8#include <linux/cn_proc.h>
9
10#include <stddef.h>
11#include <stdio.h>
12#include <stdlib.h>
13#include <unistd.h>
14#include <strings.h>
15#include <errno.h>
16#include <signal.h>
17#include <string.h>
18
19#include "../kselftest.h"
20
/*
 * Size of a filtered request: netlink header, connector header and a
 * struct proc_input payload.
 */
#define NL_MESSAGE_SIZE (sizeof(struct nlmsghdr) + sizeof(struct cn_msg) + \
			 sizeof(struct proc_input))
/*
 * Size of an unfiltered request: netlink header, connector header and a
 * single int payload.
 */
#define NL_MESSAGE_SIZE_NF (sizeof(struct nlmsghdr) + sizeof(struct cn_msg) + \
			 sizeof(int))

/* Number of epoll events fetched per epoll_wait() call. */
#define MAX_EVENTS 1

/*
 * Set to 1 by the SIGINT handler and polled by the main loop.
 * sig_atomic_t is the type that may safely be written from a signal
 * handler; the storage-class specifier comes first in the declaration.
 */
static volatile sig_atomic_t interrupted;
/*
 * nl_sock:   the netlink connector socket shared by all helpers.
 * ret_errno: errno value saved by the most recent failing call.
 * tcount:    number of NLMSG_DONE messages handled so far.
 */
static int nl_sock, ret_errno, tcount;
/* epoll registration record for nl_sock. */
static struct epoll_event evn;

/* Non-zero when "-f" was given: requests carry a struct proc_input. */
static int filter;

/*
 * Event printing goes to printf() when ENABLE_PRINTS is defined and to
 * ksft_print_msg() otherwise.
 */
#ifdef ENABLE_PRINTS
#define Printf printf
#else
#define Printf ksft_print_msg
#endif
39
40int send_message(void *pinp)
41{
42	char buff[NL_MESSAGE_SIZE];
43	struct nlmsghdr *hdr;
44	struct cn_msg *msg;
45
46	hdr = (struct nlmsghdr *)buff;
47	if (filter)
48		hdr->nlmsg_len = NL_MESSAGE_SIZE;
49	else
50		hdr->nlmsg_len = NL_MESSAGE_SIZE_NF;
51	hdr->nlmsg_type = NLMSG_DONE;
52	hdr->nlmsg_flags = 0;
53	hdr->nlmsg_seq = 0;
54	hdr->nlmsg_pid = getpid();
55
56	msg = (struct cn_msg *)NLMSG_DATA(hdr);
57	msg->id.idx = CN_IDX_PROC;
58	msg->id.val = CN_VAL_PROC;
59	msg->seq = 0;
60	msg->ack = 0;
61	msg->flags = 0;
62
63	if (filter) {
64		msg->len = sizeof(struct proc_input);
65		((struct proc_input *)msg->data)->mcast_op =
66			((struct proc_input *)pinp)->mcast_op;
67		((struct proc_input *)msg->data)->event_type =
68			((struct proc_input *)pinp)->event_type;
69	} else {
70		msg->len = sizeof(int);
71		*(int *)msg->data = *(enum proc_cn_mcast_op *)pinp;
72	}
73
74	if (send(nl_sock, hdr, hdr->nlmsg_len, 0) == -1) {
75		ret_errno = errno;
76		perror("send failed");
77		return -3;
78	}
79	return 0;
80}
81
82int register_proc_netlink(int *efd, void *input)
83{
84	struct sockaddr_nl sa_nl;
85	int err = 0, epoll_fd;
86
87	nl_sock = socket(PF_NETLINK, SOCK_DGRAM, NETLINK_CONNECTOR);
88
89	if (nl_sock == -1) {
90		ret_errno = errno;
91		perror("socket failed");
92		return -1;
93	}
94
95	bzero(&sa_nl, sizeof(sa_nl));
96	sa_nl.nl_family = AF_NETLINK;
97	sa_nl.nl_groups = CN_IDX_PROC;
98	sa_nl.nl_pid    = getpid();
99
100	if (bind(nl_sock, (struct sockaddr *)&sa_nl, sizeof(sa_nl)) == -1) {
101		ret_errno = errno;
102		perror("bind failed");
103		return -2;
104	}
105
106	epoll_fd = epoll_create1(EPOLL_CLOEXEC);
107	if (epoll_fd < 0) {
108		ret_errno = errno;
109		perror("epoll_create1 failed");
110		return -2;
111	}
112
113	err = send_message(input);
114
115	if (err < 0)
116		return err;
117
118	evn.events = EPOLLIN;
119	evn.data.fd = nl_sock;
120	if (epoll_ctl(epoll_fd, EPOLL_CTL_ADD, nl_sock, &evn) < 0) {
121		ret_errno = errno;
122		perror("epoll_ctl failed");
123		return -3;
124	}
125	*efd = epoll_fd;
126	return 0;
127}
128
129static void sigint(int sig)
130{
131	interrupted = 1;
132}
133
134int handle_packet(char *buff, int fd, struct proc_event *event)
135{
136	struct nlmsghdr *hdr;
137
138	hdr = (struct nlmsghdr *)buff;
139
140	if (hdr->nlmsg_type == NLMSG_ERROR) {
141		perror("NLMSG_ERROR error\n");
142		return -3;
143	} else if (hdr->nlmsg_type == NLMSG_DONE) {
144		event = (struct proc_event *)
145			((struct cn_msg *)NLMSG_DATA(hdr))->data;
146		tcount++;
147		switch (event->what) {
148		case PROC_EVENT_EXIT:
149			Printf("Exit process %d (tgid %d) with code %d, signal %d\n",
150			       event->event_data.exit.process_pid,
151			       event->event_data.exit.process_tgid,
152			       event->event_data.exit.exit_code,
153			       event->event_data.exit.exit_signal);
154			break;
155		case PROC_EVENT_FORK:
156			Printf("Fork process %d (tgid %d), parent %d (tgid %d)\n",
157			       event->event_data.fork.child_pid,
158			       event->event_data.fork.child_tgid,
159			       event->event_data.fork.parent_pid,
160			       event->event_data.fork.parent_tgid);
161			break;
162		case PROC_EVENT_EXEC:
163			Printf("Exec process %d (tgid %d)\n",
164			       event->event_data.exec.process_pid,
165			       event->event_data.exec.process_tgid);
166			break;
167		case PROC_EVENT_UID:
168			Printf("UID process %d (tgid %d) uid %d euid %d\n",
169			       event->event_data.id.process_pid,
170			       event->event_data.id.process_tgid,
171			       event->event_data.id.r.ruid,
172			       event->event_data.id.e.euid);
173			break;
174		case PROC_EVENT_GID:
175			Printf("GID process %d (tgid %d) gid %d egid %d\n",
176			       event->event_data.id.process_pid,
177			       event->event_data.id.process_tgid,
178			       event->event_data.id.r.rgid,
179			       event->event_data.id.e.egid);
180			break;
181		case PROC_EVENT_SID:
182			Printf("SID process %d (tgid %d)\n",
183			       event->event_data.sid.process_pid,
184			       event->event_data.sid.process_tgid);
185			break;
186		case PROC_EVENT_PTRACE:
187			Printf("Ptrace process %d (tgid %d), Tracer %d (tgid %d)\n",
188			       event->event_data.ptrace.process_pid,
189			       event->event_data.ptrace.process_tgid,
190			       event->event_data.ptrace.tracer_pid,
191			       event->event_data.ptrace.tracer_tgid);
192			break;
193		case PROC_EVENT_COMM:
194			Printf("Comm process %d (tgid %d) comm %s\n",
195			       event->event_data.comm.process_pid,
196			       event->event_data.comm.process_tgid,
197			       event->event_data.comm.comm);
198			break;
199		case PROC_EVENT_COREDUMP:
200			Printf("Coredump process %d (tgid %d) parent %d, (tgid %d)\n",
201			       event->event_data.coredump.process_pid,
202			       event->event_data.coredump.process_tgid,
203			       event->event_data.coredump.parent_pid,
204			       event->event_data.coredump.parent_tgid);
205			break;
206		default:
207			break;
208		}
209	}
210	return 0;
211}
212
213int handle_events(int epoll_fd, struct proc_event *pev)
214{
215	char buff[CONNECTOR_MAX_MSG_SIZE];
216	struct epoll_event ev[MAX_EVENTS];
217	int i, event_count = 0, err = 0;
218
219	event_count = epoll_wait(epoll_fd, ev, MAX_EVENTS, -1);
220	if (event_count < 0) {
221		ret_errno = errno;
222		if (ret_errno != EINTR)
223			perror("epoll_wait failed");
224		return -3;
225	}
226	for (i = 0; i < event_count; i++) {
227		if (!(ev[i].events & EPOLLIN))
228			continue;
229		if (recv(ev[i].data.fd, buff, sizeof(buff), 0) == -1) {
230			ret_errno = errno;
231			perror("recv failed");
232			return -3;
233		}
234		err = handle_packet(buff, ev[i].data.fd, pev);
235		if (err < 0)
236			return err;
237	}
238	return 0;
239}
240
241int main(int argc, char *argv[])
242{
243	int epoll_fd, err;
244	struct proc_event proc_ev;
245	struct proc_input input;
246
247	signal(SIGINT, sigint);
248
249	if (argc > 2) {
250		printf("Expected 0(assume no-filter) or 1 argument(-f)\n");
251		exit(KSFT_SKIP);
252	}
253
254	if (argc == 2) {
255		if (strcmp(argv[1], "-f") == 0) {
256			filter = 1;
257		} else {
258			printf("Valid option : -f (for filter feature)\n");
259			exit(KSFT_SKIP);
260		}
261	}
262
263	if (filter) {
264		input.event_type = PROC_EVENT_NONZERO_EXIT;
265		input.mcast_op = PROC_CN_MCAST_LISTEN;
266		err = register_proc_netlink(&epoll_fd, (void*)&input);
267	} else {
268		enum proc_cn_mcast_op op = PROC_CN_MCAST_LISTEN;
269		err = register_proc_netlink(&epoll_fd, (void*)&op);
270	}
271
272	if (err < 0) {
273		if (err == -2)
274			close(nl_sock);
275		if (err == -3) {
276			close(nl_sock);
277			close(epoll_fd);
278		}
279		exit(1);
280	}
281
282	while (!interrupted) {
283		err = handle_events(epoll_fd, &proc_ev);
284		if (err < 0) {
285			if (ret_errno == EINTR)
286				continue;
287			if (err == -2)
288				close(nl_sock);
289			if (err == -3) {
290				close(nl_sock);
291				close(epoll_fd);
292			}
293			exit(1);
294		}
295	}
296
297	if (filter) {
298		input.mcast_op = PROC_CN_MCAST_IGNORE;
299		send_message((void*)&input);
300	} else {
301		enum proc_cn_mcast_op op = PROC_CN_MCAST_IGNORE;
302		send_message((void*)&op);
303	}
304
305	close(epoll_fd);
306	close(nl_sock);
307
308	printf("Done total count: %d\n", tcount);
309	exit(0);
310}
311