// SPDX-License-Identifier: GPL-2.0-only
/* Copyright (c) 2020 Facebook */

#include <linux/init.h>
#include <linux/namei.h>
#include <linux/pid_namespace.h>
#include <linux/fs.h>
#include <linux/fdtable.h>
#include <linux/filter.h>
#include <linux/bpf_mem_alloc.h>
#include <linux/btf_ids.h>
#include <linux/mm_types.h>
#include "mmap_unlock_work.h"

static const char * const iter_task_type_names[] = {
	"ALL",
	"TID",
	"PID",
};

struct bpf_iter_seq_task_common {
	struct pid_namespace *ns;
	enum bpf_iter_task_type	type;
	u32 pid;
	u32 pid_visiting;
};

struct bpf_iter_seq_task_info {
	/* The first field must be struct bpf_iter_seq_task_common.
	 * This is assumed by the {init, fini}_seq_pidns() callback functions.
	 */
	struct bpf_iter_seq_task_common common;
	u32 tid;
};

static struct task_struct *task_group_seq_get_next(struct bpf_iter_seq_task_common *common,
						   u32 *tid,
						   bool skip_if_dup_files)
{
	struct task_struct *task;
	struct pid *pid;
	u32 next_tid;

	if (!*tid) {
		/* The first time, the iterator calls this function. */
		pid = find_pid_ns(common->pid, common->ns);
		task = get_pid_task(pid, PIDTYPE_TGID);
		if (!task)
			return NULL;

		*tid = common->pid;
		common->pid_visiting = common->pid;

		return task;
	}

	/* If the control returns to user space and comes back to the
	 * kernel again, *tid and common->pid_visiting should be the
	 * same for task_seq_start() to pick up the correct task.
	 */
	if (*tid == common->pid_visiting) {
		pid = find_pid_ns(common->pid_visiting, common->ns);
		task = get_pid_task(pid, PIDTYPE_PID);

		return task;
	}

	task = find_task_by_pid_ns(common->pid_visiting, common->ns);
	if (!task)
		return NULL;

retry:
	task = __next_thread(task);
	if (!task)
		return NULL;

	next_tid = __task_pid_nr_ns(task, PIDTYPE_PID, common->ns);
	if (!next_tid)
		goto retry;

	if (skip_if_dup_files && task->files == task->group_leader->files)
		goto retry;

	*tid = common->pid_visiting = next_tid;
	get_task_struct(task);
	return task;
}

static struct task_struct *task_seq_get_next(struct bpf_iter_seq_task_common *common,
					     u32 *tid,
					     bool skip_if_dup_files)
{
	struct task_struct *task = NULL;
	struct pid *pid;

	if (common->type == BPF_TASK_ITER_TID) {
		if (*tid && *tid != common->pid)
			return NULL;
		rcu_read_lock();
		pid = find_pid_ns(common->pid, common->ns);
		if (pid) {
			task = get_pid_task(pid, PIDTYPE_TGID);
			*tid = common->pid;
		}
		rcu_read_unlock();

		return task;
	}

	if (common->type == BPF_TASK_ITER_TGID) {
		rcu_read_lock();
		task = task_group_seq_get_next(common, tid, skip_if_dup_files);
		rcu_read_unlock();

		return task;
	}

	rcu_read_lock();
retry:
	pid = find_ge_pid(*tid, common->ns);
	if (pid) {
		*tid = pid_nr_ns(pid, common->ns);
		task = get_pid_task(pid, PIDTYPE_PID);
		if (!task) {
			++*tid;
			goto retry;
		} else if (skip_if_dup_files && !thread_group_leader(task) &&
			   task->files == task->group_leader->files) {
			put_task_struct(task);
			task = NULL;
			++*tid;
			goto retry;
		}
	}
	rcu_read_unlock();

	return task;
}

static void *task_seq_start(struct seq_file *seq, loff_t *pos)
{
	struct bpf_iter_seq_task_info *info = seq->private;
	struct task_struct *task;

	task = task_seq_get_next(&info->common, &info->tid, false);
	if (!task)
		return NULL;

	if (*pos == 0)
		++*pos;
	return task;
}

static void *task_seq_next(struct seq_file *seq, void *v, loff_t *pos)
{
	struct bpf_iter_seq_task_info *info = seq->private;
	struct task_struct *task;

	++*pos;
	++info->tid;
	put_task_struct((struct task_struct *)v);
	task = task_seq_get_next(&info->common, &info->tid, false);
	if (!task)
		return NULL;

	return task;
}

struct bpf_iter__task {
	__bpf_md_ptr(struct bpf_iter_meta *, meta);
	__bpf_md_ptr(struct task_struct *, task);
};

DEFINE_BPF_ITER_FUNC(task, struct bpf_iter_meta *meta, struct task_struct *task)
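
/* Illustrative only, not part of this file's build: a minimal BPF-side sketch
 * (identifiers such as dump_task are made up) of a program that could attach
 * to this "task" iterator target and print through its seq_file:
 *
 *	SEC("iter/task")
 *	int dump_task(struct bpf_iter__task *ctx)
 *	{
 *		struct seq_file *seq = ctx->meta->seq;
 *		struct task_struct *task = ctx->task;
 *
 *		if (!task)
 *			return 0;
 *		BPF_SEQ_PRINTF(seq, "%8d %s\n", task->pid, task->comm);
 *		return 0;
 *	}
 *
 * Reading the pinned iterator (e.g. cat on a bpffs file) then drives
 * task_seq_start/next/show/stop below for each matching task.
 */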

static int __task_seq_show(struct seq_file *seq, struct task_struct *task,
			   bool in_stop)
{
	struct bpf_iter_meta meta;
	struct bpf_iter__task ctx;
	struct bpf_prog *prog;

	meta.seq = seq;
	prog = bpf_iter_get_info(&meta, in_stop);
	if (!prog)
		return 0;

	ctx.meta = &meta;
	ctx.task = task;
	return bpf_iter_run_prog(prog, &ctx);
}

static int task_seq_show(struct seq_file *seq, void *v)
{
	return __task_seq_show(seq, v, false);
}

static void task_seq_stop(struct seq_file *seq, void *v)
{
	if (!v)
		(void)__task_seq_show(seq, v, true);
	else
		put_task_struct((struct task_struct *)v);
}

static int bpf_iter_attach_task(struct bpf_prog *prog,
				union bpf_iter_link_info *linfo,
				struct bpf_iter_aux_info *aux)
{
	unsigned int flags;
	struct pid *pid;
	pid_t tgid;

	if ((!!linfo->task.tid + !!linfo->task.pid + !!linfo->task.pid_fd) > 1)
		return -EINVAL;

	aux->task.type = BPF_TASK_ITER_ALL;
	if (linfo->task.tid != 0) {
		aux->task.type = BPF_TASK_ITER_TID;
		aux->task.pid = linfo->task.tid;
	}
	if (linfo->task.pid != 0) {
		aux->task.type = BPF_TASK_ITER_TGID;
		aux->task.pid = linfo->task.pid;
	}
	if (linfo->task.pid_fd != 0) {
		aux->task.type = BPF_TASK_ITER_TGID;

		pid = pidfd_get_pid(linfo->task.pid_fd, &flags);
		if (IS_ERR(pid))
			return PTR_ERR(pid);

		tgid = pid_nr_ns(pid, task_active_pid_ns(current));
		aux->task.pid = tgid;
		put_pid(pid);
	}

	return 0;
}
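
/* Rough user-space sketch (assumed libbpf usage, not verified against a
 * particular libbpf version; "skel" and "dump_task" are made-up names) of how
 * the tid/pid/pid_fd filters parsed above are supplied when attaching:
 *
 *	union bpf_iter_link_info linfo = {};
 *	LIBBPF_OPTS(bpf_iter_attach_opts, opts);
 *	struct bpf_link *link;
 *
 *	linfo.task.pid = target_pid;	// at most one of .tid/.pid/.pid_fd
 *	opts.link_info = &linfo;
 *	opts.link_info_len = sizeof(linfo);
 *	link = bpf_program__attach_iter(skel->progs.dump_task, &opts);
 */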

static const struct seq_operations task_seq_ops = {
	.start	= task_seq_start,
	.next	= task_seq_next,
	.stop	= task_seq_stop,
	.show	= task_seq_show,
};

struct bpf_iter_seq_task_file_info {
	/* The first field must be struct bpf_iter_seq_task_common.
	 * This is assumed by the {init, fini}_seq_pidns() callback functions.
	 */
	struct bpf_iter_seq_task_common common;
	struct task_struct *task;
	u32 tid;
	u32 fd;
};

static struct file *
task_file_seq_get_next(struct bpf_iter_seq_task_file_info *info)
{
	u32 saved_tid = info->tid;
	struct task_struct *curr_task;
	unsigned int curr_fd = info->fd;

	/* If this function returns a non-NULL file object,
	 * it holds a reference to the task/file.
	 * Otherwise, it does not hold any reference.
	 */
again:
	if (info->task) {
		curr_task = info->task;
		curr_fd = info->fd;
	} else {
		curr_task = task_seq_get_next(&info->common, &info->tid, true);
		if (!curr_task) {
			info->task = NULL;
			return NULL;
		}

		/* set info->task */
		info->task = curr_task;
		if (saved_tid == info->tid)
			curr_fd = info->fd;
		else
			curr_fd = 0;
	}

	rcu_read_lock();
	for (;; curr_fd++) {
		struct file *f;
		f = task_lookup_next_fdget_rcu(curr_task, &curr_fd);
		if (!f)
			break;

		/* set info->fd */
		info->fd = curr_fd;
		rcu_read_unlock();
		return f;
	}

	/* the current task is done, go to the next task */
	rcu_read_unlock();
	put_task_struct(curr_task);

	if (info->common.type == BPF_TASK_ITER_TID) {
		info->task = NULL;
		return NULL;
	}

	info->task = NULL;
	info->fd = 0;
	saved_tid = ++(info->tid);
	goto again;
}

static void *task_file_seq_start(struct seq_file *seq, loff_t *pos)
{
	struct bpf_iter_seq_task_file_info *info = seq->private;
	struct file *file;

	info->task = NULL;
	file = task_file_seq_get_next(info);
	if (file && *pos == 0)
		++*pos;

	return file;
}

static void *task_file_seq_next(struct seq_file *seq, void *v, loff_t *pos)
{
	struct bpf_iter_seq_task_file_info *info = seq->private;

	++*pos;
	++info->fd;
	fput((struct file *)v);
	return task_file_seq_get_next(info);
}

struct bpf_iter__task_file {
	__bpf_md_ptr(struct bpf_iter_meta *, meta);
	__bpf_md_ptr(struct task_struct *, task);
	u32 fd __aligned(8);
	__bpf_md_ptr(struct file *, file);
};

DEFINE_BPF_ITER_FUNC(task_file, struct bpf_iter_meta *meta,
		     struct task_struct *task, u32 fd,
		     struct file *file)
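
/* Illustrative sketch only (function name and format string are made up): a
 * BPF program for the "task_file" target sees one (task, fd, file) triple per
 * show() call, e.g.:
 *
 *	SEC("iter/task_file")
 *	int dump_task_file(struct bpf_iter__task_file *ctx)
 *	{
 *		struct seq_file *seq = ctx->meta->seq;
 *		struct task_struct *task = ctx->task;
 *		struct file *file = ctx->file;
 *
 *		if (!task || !file)
 *			return 0;
 *		BPF_SEQ_PRINTF(seq, "%8d %8d %lx\n", task->tgid, ctx->fd,
 *			       (long)file->f_op);
 *		return 0;
 *	}
 */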

static int __task_file_seq_show(struct seq_file *seq, struct file *file,
				bool in_stop)
{
	struct bpf_iter_seq_task_file_info *info = seq->private;
	struct bpf_iter__task_file ctx;
	struct bpf_iter_meta meta;
	struct bpf_prog *prog;

	meta.seq = seq;
	prog = bpf_iter_get_info(&meta, in_stop);
	if (!prog)
		return 0;

	ctx.meta = &meta;
	ctx.task = info->task;
	ctx.fd = info->fd;
	ctx.file = file;
	return bpf_iter_run_prog(prog, &ctx);
}

static int task_file_seq_show(struct seq_file *seq, void *v)
{
	return __task_file_seq_show(seq, v, false);
}

static void task_file_seq_stop(struct seq_file *seq, void *v)
{
	struct bpf_iter_seq_task_file_info *info = seq->private;

	if (!v) {
		(void)__task_file_seq_show(seq, v, true);
	} else {
		fput((struct file *)v);
		put_task_struct(info->task);
		info->task = NULL;
	}
}

static int init_seq_pidns(void *priv_data, struct bpf_iter_aux_info *aux)
{
	struct bpf_iter_seq_task_common *common = priv_data;

	common->ns = get_pid_ns(task_active_pid_ns(current));
	common->type = aux->task.type;
	common->pid = aux->task.pid;

	return 0;
}

static void fini_seq_pidns(void *priv_data)
{
	struct bpf_iter_seq_task_common *common = priv_data;

	put_pid_ns(common->ns);
}

static const struct seq_operations task_file_seq_ops = {
	.start	= task_file_seq_start,
	.next	= task_file_seq_next,
	.stop	= task_file_seq_stop,
	.show	= task_file_seq_show,
};

struct bpf_iter_seq_task_vma_info {
	/* The first field must be struct bpf_iter_seq_task_common.
	 * This is assumed by the {init, fini}_seq_pidns() callback functions.
	 */
	struct bpf_iter_seq_task_common common;
	struct task_struct *task;
	struct mm_struct *mm;
	struct vm_area_struct *vma;
	u32 tid;
	unsigned long prev_vm_start;
	unsigned long prev_vm_end;
};

enum bpf_task_vma_iter_find_op {
	task_vma_iter_first_vma,   /* use find_vma() with addr 0 */
	task_vma_iter_next_vma,    /* use vma_next() with curr_vma */
	task_vma_iter_find_vma,    /* use find_vma() to find next vma */
};

static struct vm_area_struct *
task_vma_seq_get_next(struct bpf_iter_seq_task_vma_info *info)
{
	enum bpf_task_vma_iter_find_op op;
	struct vm_area_struct *curr_vma;
	struct task_struct *curr_task;
	struct mm_struct *curr_mm;
	u32 saved_tid = info->tid;

	/* If this function returns a non-NULL vma, it holds a reference to
	 * the task_struct, holds a refcount on mm->mm_users, and holds
	 * read lock on vma->mm->mmap_lock.
	 * If this function returns NULL, it does not hold any reference or
	 * lock.
	 */
	if (info->task) {
		curr_task = info->task;
		curr_vma = info->vma;
		curr_mm = info->mm;
		/* In case of lock contention, drop mmap_lock to unblock
		 * the writer.
		 *
		 * After relock, call find_vma(mm, prev_vm_end - 1) to find
		 * the new vma to process.
		 *
		 *   +------+------+-----------+
		 *   | VMA1 | VMA2 | VMA3      |
		 *   +------+------+-----------+
		 *   |      |      |           |
		 *  4k     8k     16k         400k
		 *
		 * For example, curr_vma == VMA2. Before unlock, we set
		 *
		 *    prev_vm_start = 8k
		 *    prev_vm_end   = 16k
		 *
		 * There are a few cases:
		 *
		 * 1) VMA2 is freed, but VMA3 exists.
		 *
		 *    find_vma() will return VMA3, just process VMA3.
		 *
		 * 2) VMA2 still exists.
		 *
		 *    find_vma() will return VMA2, process VMA2->next.
		 *
		 * 3) no more vma in this mm.
		 *
		 *    Process the next task.
		 *
		 * 4) find_vma() returns a different vma, VMA2'.
		 *
		 *    4.1) If VMA2 covers the same range as VMA2', skip VMA2',
		 *         because we already covered the range;
		 *    4.2) If VMA2 and VMA2' cover different ranges, process
		 *         VMA2'.
		 */
		if (mmap_lock_is_contended(curr_mm)) {
			info->prev_vm_start = curr_vma->vm_start;
			info->prev_vm_end = curr_vma->vm_end;
			op = task_vma_iter_find_vma;
			mmap_read_unlock(curr_mm);
			if (mmap_read_lock_killable(curr_mm)) {
				mmput(curr_mm);
				goto finish;
			}
		} else {
			op = task_vma_iter_next_vma;
		}
	} else {
again:
		curr_task = task_seq_get_next(&info->common, &info->tid, true);
		if (!curr_task) {
			info->tid++;
			goto finish;
		}

		if (saved_tid != info->tid) {
			/* new task, process the first vma */
			op = task_vma_iter_first_vma;
		} else {
			/* Found the same tid, which means the user space
			 * finished data in previous buffer and read more.
			 * We dropped mmap_lock before returning to user
			 * space, so it is necessary to use find_vma() to
			 * find the next vma to process.
			 */
			op = task_vma_iter_find_vma;
		}

		curr_mm = get_task_mm(curr_task);
		if (!curr_mm)
			goto next_task;

		if (mmap_read_lock_killable(curr_mm)) {
			mmput(curr_mm);
			goto finish;
		}
	}

	switch (op) {
	case task_vma_iter_first_vma:
		curr_vma = find_vma(curr_mm, 0);
		break;
	case task_vma_iter_next_vma:
		curr_vma = find_vma(curr_mm, curr_vma->vm_end);
		break;
	case task_vma_iter_find_vma:
		/* We dropped mmap_lock so it is necessary to use find_vma
		 * to find the next vma. This is similar to the mechanism
		 * in show_smaps_rollup().
		 */
		curr_vma = find_vma(curr_mm, info->prev_vm_end - 1);
		/* case 1) and 4.2) above just use curr_vma */

		/* check for case 2) or case 4.1) above */
		if (curr_vma &&
		    curr_vma->vm_start == info->prev_vm_start &&
		    curr_vma->vm_end == info->prev_vm_end)
			curr_vma = find_vma(curr_mm, curr_vma->vm_end);
		break;
	}
	if (!curr_vma) {
		/* case 3) above, or case 2) 4.1) with vma->next == NULL */
		mmap_read_unlock(curr_mm);
		mmput(curr_mm);
		goto next_task;
	}
	info->task = curr_task;
	info->vma = curr_vma;
	info->mm = curr_mm;
	return curr_vma;

next_task:
	if (info->common.type == BPF_TASK_ITER_TID)
		goto finish;

	put_task_struct(curr_task);
	info->task = NULL;
	info->mm = NULL;
	info->tid++;
	goto again;

finish:
	if (curr_task)
		put_task_struct(curr_task);
	info->task = NULL;
	info->vma = NULL;
	info->mm = NULL;
	return NULL;
}

static void *task_vma_seq_start(struct seq_file *seq, loff_t *pos)
{
	struct bpf_iter_seq_task_vma_info *info = seq->private;
	struct vm_area_struct *vma;

	vma = task_vma_seq_get_next(info);
	if (vma && *pos == 0)
		++*pos;

	return vma;
}

static void *task_vma_seq_next(struct seq_file *seq, void *v, loff_t *pos)
{
	struct bpf_iter_seq_task_vma_info *info = seq->private;

	++*pos;
	return task_vma_seq_get_next(info);
}

struct bpf_iter__task_vma {
	__bpf_md_ptr(struct bpf_iter_meta *, meta);
	__bpf_md_ptr(struct task_struct *, task);
	__bpf_md_ptr(struct vm_area_struct *, vma);
};

DEFINE_BPF_ITER_FUNC(task_vma, struct bpf_iter_meta *meta,
		     struct task_struct *task, struct vm_area_struct *vma)
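
/* Illustrative sketch only (identifiers are made up): a BPF program attached
 * to the "task_vma" target is invoked once per vma while the mmap_lock is
 * read-held by task_vma_seq_get_next() above:
 *
 *	SEC("iter/task_vma")
 *	int dump_task_vma(struct bpf_iter__task_vma *ctx)
 *	{
 *		struct task_struct *task = ctx->task;
 *		struct vm_area_struct *vma = ctx->vma;
 *
 *		if (!task || !vma)
 *			return 0;
 *		BPF_SEQ_PRINTF(ctx->meta->seq, "%8d %lx-%lx\n",
 *			       task->pid, vma->vm_start, vma->vm_end);
 *		return 0;
 *	}
 */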

static int __task_vma_seq_show(struct seq_file *seq, bool in_stop)
{
	struct bpf_iter_seq_task_vma_info *info = seq->private;
	struct bpf_iter__task_vma ctx;
	struct bpf_iter_meta meta;
	struct bpf_prog *prog;

	meta.seq = seq;
	prog = bpf_iter_get_info(&meta, in_stop);
	if (!prog)
		return 0;

	ctx.meta = &meta;
	ctx.task = info->task;
	ctx.vma = info->vma;
	return bpf_iter_run_prog(prog, &ctx);
}

static int task_vma_seq_show(struct seq_file *seq, void *v)
{
	return __task_vma_seq_show(seq, false);
}

static void task_vma_seq_stop(struct seq_file *seq, void *v)
{
	struct bpf_iter_seq_task_vma_info *info = seq->private;

	if (!v) {
		(void)__task_vma_seq_show(seq, true);
	} else {
		/* info->vma has not been seen by the BPF program. If the
		 * user space reads more, task_vma_seq_get_next should
		 * return this vma again. Set prev_vm_start to ~0UL,
		 * so that we don't skip the vma returned by the next
		 * find_vma() (case task_vma_iter_find_vma in
		 * task_vma_seq_get_next()).
		 */
		info->prev_vm_start = ~0UL;
		info->prev_vm_end = info->vma->vm_end;
		mmap_read_unlock(info->mm);
		mmput(info->mm);
		info->mm = NULL;
		put_task_struct(info->task);
		info->task = NULL;
	}
}

static const struct seq_operations task_vma_seq_ops = {
	.start	= task_vma_seq_start,
	.next	= task_vma_seq_next,
	.stop	= task_vma_seq_stop,
	.show	= task_vma_seq_show,
};

static const struct bpf_iter_seq_info task_seq_info = {
	.seq_ops		= &task_seq_ops,
	.init_seq_private	= init_seq_pidns,
	.fini_seq_private	= fini_seq_pidns,
	.seq_priv_size		= sizeof(struct bpf_iter_seq_task_info),
};

static int bpf_iter_fill_link_info(const struct bpf_iter_aux_info *aux, struct bpf_link_info *info)
{
	switch (aux->task.type) {
	case BPF_TASK_ITER_TID:
		info->iter.task.tid = aux->task.pid;
		break;
	case BPF_TASK_ITER_TGID:
		info->iter.task.pid = aux->task.pid;
		break;
	default:
		break;
	}
	return 0;
}

static void bpf_iter_task_show_fdinfo(const struct bpf_iter_aux_info *aux, struct seq_file *seq)
{
	seq_printf(seq, "task_type:\t%s\n", iter_task_type_names[aux->task.type]);
	if (aux->task.type == BPF_TASK_ITER_TID)
		seq_printf(seq, "tid:\t%u\n", aux->task.pid);
	else if (aux->task.type == BPF_TASK_ITER_TGID)
		seq_printf(seq, "pid:\t%u\n", aux->task.pid);
}

static struct bpf_iter_reg task_reg_info = {
	.target			= "task",
	.attach_target		= bpf_iter_attach_task,
	.feature		= BPF_ITER_RESCHED,
	.ctx_arg_info_size	= 1,
	.ctx_arg_info		= {
		{ offsetof(struct bpf_iter__task, task),
		  PTR_TO_BTF_ID_OR_NULL | PTR_TRUSTED },
	},
	.seq_info		= &task_seq_info,
	.fill_link_info		= bpf_iter_fill_link_info,
	.show_fdinfo		= bpf_iter_task_show_fdinfo,
};

static const struct bpf_iter_seq_info task_file_seq_info = {
	.seq_ops		= &task_file_seq_ops,
	.init_seq_private	= init_seq_pidns,
	.fini_seq_private	= fini_seq_pidns,
	.seq_priv_size		= sizeof(struct bpf_iter_seq_task_file_info),
};

static struct bpf_iter_reg task_file_reg_info = {
	.target			= "task_file",
	.attach_target		= bpf_iter_attach_task,
	.feature		= BPF_ITER_RESCHED,
	.ctx_arg_info_size	= 2,
	.ctx_arg_info		= {
		{ offsetof(struct bpf_iter__task_file, task),
		  PTR_TO_BTF_ID_OR_NULL },
		{ offsetof(struct bpf_iter__task_file, file),
		  PTR_TO_BTF_ID_OR_NULL },
	},
	.seq_info		= &task_file_seq_info,
	.fill_link_info		= bpf_iter_fill_link_info,
	.show_fdinfo		= bpf_iter_task_show_fdinfo,
};

static const struct bpf_iter_seq_info task_vma_seq_info = {
	.seq_ops		= &task_vma_seq_ops,
	.init_seq_private	= init_seq_pidns,
	.fini_seq_private	= fini_seq_pidns,
	.seq_priv_size		= sizeof(struct bpf_iter_seq_task_vma_info),
};

static struct bpf_iter_reg task_vma_reg_info = {
	.target			= "task_vma",
	.attach_target		= bpf_iter_attach_task,
	.feature		= BPF_ITER_RESCHED,
	.ctx_arg_info_size	= 2,
	.ctx_arg_info		= {
		{ offsetof(struct bpf_iter__task_vma, task),
		  PTR_TO_BTF_ID_OR_NULL },
		{ offsetof(struct bpf_iter__task_vma, vma),
		  PTR_TO_BTF_ID_OR_NULL },
	},
	.seq_info		= &task_vma_seq_info,
	.fill_link_info		= bpf_iter_fill_link_info,
	.show_fdinfo		= bpf_iter_task_show_fdinfo,
};

BPF_CALL_5(bpf_find_vma, struct task_struct *, task, u64, start,
	   bpf_callback_t, callback_fn, void *, callback_ctx, u64, flags)
{
	struct mmap_unlock_irq_work *work = NULL;
	struct vm_area_struct *vma;
	bool irq_work_busy = false;
	struct mm_struct *mm;
	int ret = -ENOENT;

	if (flags)
		return -EINVAL;

	if (!task)
		return -ENOENT;

	mm = task->mm;
	if (!mm)
		return -ENOENT;

	irq_work_busy = bpf_mmap_unlock_get_irq_work(&work);

	if (irq_work_busy || !mmap_read_trylock(mm))
		return -EBUSY;

	vma = find_vma(mm, start);

	if (vma && vma->vm_start <= start && vma->vm_end > start) {
		callback_fn((u64)(long)task, (u64)(long)vma,
			    (u64)(long)callback_ctx, 0, 0);
		ret = 0;
	}
	bpf_mmap_unlock_mm(work, mm);
	return ret;
}

const struct bpf_func_proto bpf_find_vma_proto = {
	.func		= bpf_find_vma,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_BTF_ID,
	.arg1_btf_id	= &btf_tracing_ids[BTF_TRACING_TYPE_TASK],
	.arg2_type	= ARG_ANYTHING,
	.arg3_type	= ARG_PTR_TO_FUNC,
	.arg4_type	= ARG_PTR_TO_STACK_OR_NULL,
	.arg5_type	= ARG_ANYTHING,
};
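
/* Hedged BPF-side sketch of using this helper (the callback and ctx struct
 * below are made-up illustrations, loosely following the selftests style):
 *
 *	struct cb_ctx { __u64 vm_start; };
 *
 *	static long check_vma(struct task_struct *task,
 *			      struct vm_area_struct *vma, struct cb_ctx *ctx)
 *	{
 *		ctx->vm_start = vma->vm_start;
 *		return 0;
 *	}
 *
 *	// somewhere in a BPF program:
 *	struct cb_ctx data = {};
 *	long err = bpf_find_vma(task, addr, check_vma, &data, 0);
 *
 * A non-zero return (-ENOENT, -EBUSY, -EINVAL) mirrors the error paths above.
 */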

struct bpf_iter_task_vma_kern_data {
	struct task_struct *task;
	struct mm_struct *mm;
	struct mmap_unlock_irq_work *work;
	struct vma_iterator vmi;
};

struct bpf_iter_task_vma {
	/* opaque iterator state; having __u64 here allows to preserve correct
	 * alignment requirements in vmlinux.h, generated from BTF
	 */
	__u64 __opaque[1];
} __attribute__((aligned(8)));

/* Non-opaque version of bpf_iter_task_vma */
struct bpf_iter_task_vma_kern {
	struct bpf_iter_task_vma_kern_data *data;
} __attribute__((aligned(8)));

__bpf_kfunc_start_defs();

__bpf_kfunc int bpf_iter_task_vma_new(struct bpf_iter_task_vma *it,
				      struct task_struct *task, u64 addr)
{
	struct bpf_iter_task_vma_kern *kit = (void *)it;
	bool irq_work_busy = false;
	int err;

	BUILD_BUG_ON(sizeof(struct bpf_iter_task_vma_kern) != sizeof(struct bpf_iter_task_vma));
	BUILD_BUG_ON(__alignof__(struct bpf_iter_task_vma_kern) != __alignof__(struct bpf_iter_task_vma));

	/* is_iter_reg_valid_uninit guarantees that kit hasn't been initialized
	 * before, so non-NULL kit->data doesn't point to previously
	 * bpf_mem_alloc'd bpf_iter_task_vma_kern_data
	 */
	kit->data = bpf_mem_alloc(&bpf_global_ma, sizeof(struct bpf_iter_task_vma_kern_data));
	if (!kit->data)
		return -ENOMEM;

	kit->data->task = get_task_struct(task);
	kit->data->mm = task->mm;
	if (!kit->data->mm) {
		err = -ENOENT;
		goto err_cleanup_iter;
	}

	/* kit->data->work == NULL is valid after bpf_mmap_unlock_get_irq_work */
	irq_work_busy = bpf_mmap_unlock_get_irq_work(&kit->data->work);
	if (irq_work_busy || !mmap_read_trylock(kit->data->mm)) {
		err = -EBUSY;
		goto err_cleanup_iter;
	}

	vma_iter_init(&kit->data->vmi, kit->data->mm, addr);
	return 0;

err_cleanup_iter:
	if (kit->data->task)
		put_task_struct(kit->data->task);
	bpf_mem_free(&bpf_global_ma, kit->data);
	/* NULL kit->data signals failed bpf_iter_task_vma initialization */
	kit->data = NULL;
	return err;
}

__bpf_kfunc struct vm_area_struct *bpf_iter_task_vma_next(struct bpf_iter_task_vma *it)
{
	struct bpf_iter_task_vma_kern *kit = (void *)it;

	if (!kit->data) /* bpf_iter_task_vma_new failed */
		return NULL;
	return vma_next(&kit->data->vmi);
}

__bpf_kfunc void bpf_iter_task_vma_destroy(struct bpf_iter_task_vma *it)
{
	struct bpf_iter_task_vma_kern *kit = (void *)it;

	if (kit->data) {
		bpf_mmap_unlock_mm(kit->data->work, kit->data->mm);
		put_task_struct(kit->data->task);
		bpf_mem_free(&bpf_global_ma, kit->data);
	}
}

__bpf_kfunc_end_defs();
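
/* Hedged usage sketch for the open-coded vma iterator kfuncs above (the loop
 * body and variable names are illustrative; libbpf's bpf_for_each() macro
 * wraps the same new/next/destroy triple):
 *
 *	struct vm_area_struct *vma;
 *	struct bpf_iter_task_vma vma_it;
 *
 *	bpf_iter_task_vma_new(&vma_it, task, 0);
 *	while ((vma = bpf_iter_task_vma_next(&vma_it)))
 *		vma_cnt++;
 *	bpf_iter_task_vma_destroy(&vma_it);
 *
 * Note that the mmap_lock stays read-held from _new() until _destroy().
 */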

#ifdef CONFIG_CGROUPS

struct bpf_iter_css_task {
	__u64 __opaque[1];
} __attribute__((aligned(8)));

struct bpf_iter_css_task_kern {
	struct css_task_iter *css_it;
} __attribute__((aligned(8)));

__bpf_kfunc_start_defs();

__bpf_kfunc int bpf_iter_css_task_new(struct bpf_iter_css_task *it,
		struct cgroup_subsys_state *css, unsigned int flags)
{
	struct bpf_iter_css_task_kern *kit = (void *)it;

	BUILD_BUG_ON(sizeof(struct bpf_iter_css_task_kern) != sizeof(struct bpf_iter_css_task));
	BUILD_BUG_ON(__alignof__(struct bpf_iter_css_task_kern) !=
					__alignof__(struct bpf_iter_css_task));
	kit->css_it = NULL;
	switch (flags) {
	case CSS_TASK_ITER_PROCS | CSS_TASK_ITER_THREADED:
	case CSS_TASK_ITER_PROCS:
	case 0:
		break;
	default:
		return -EINVAL;
	}

	kit->css_it = bpf_mem_alloc(&bpf_global_ma, sizeof(struct css_task_iter));
	if (!kit->css_it)
		return -ENOMEM;
	css_task_iter_start(css, flags, kit->css_it);
	return 0;
}

__bpf_kfunc struct task_struct *bpf_iter_css_task_next(struct bpf_iter_css_task *it)
{
	struct bpf_iter_css_task_kern *kit = (void *)it;

	if (!kit->css_it)
		return NULL;
	return css_task_iter_next(kit->css_it);
}

__bpf_kfunc void bpf_iter_css_task_destroy(struct bpf_iter_css_task *it)
{
	struct bpf_iter_css_task_kern *kit = (void *)it;

	if (!kit->css_it)
		return;
	css_task_iter_end(kit->css_it);
	bpf_mem_free(&bpf_global_ma, kit->css_it);
}

__bpf_kfunc_end_defs();
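
/* Hedged usage sketch for the css_task kfuncs above (illustrative only; the
 * css pointer is assumed to come from the program's context, and
 * bpf_for_each() is libbpf's wrapper around the new/next/destroy triple):
 *
 *	struct task_struct *task;
 *
 *	bpf_for_each(css_task, task, css, CSS_TASK_ITER_PROCS)
 *		process_cnt++;
 *
 * Only the flag combinations accepted by bpf_iter_css_task_new() above
 * (0, CSS_TASK_ITER_PROCS, or PROCS | THREADED) are valid.
 */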

#endif /* CONFIG_CGROUPS */

struct bpf_iter_task {
	__u64 __opaque[3];
} __attribute__((aligned(8)));

struct bpf_iter_task_kern {
	struct task_struct *task;
	struct task_struct *pos;
	unsigned int flags;
} __attribute__((aligned(8)));

enum {
	/* all processes in the system */
	BPF_TASK_ITER_ALL_PROCS,
	/* all threads in the system */
	BPF_TASK_ITER_ALL_THREADS,
	/* all threads of a specific process */
	BPF_TASK_ITER_PROC_THREADS
};

__bpf_kfunc_start_defs();

__bpf_kfunc int bpf_iter_task_new(struct bpf_iter_task *it,
		struct task_struct *task__nullable, unsigned int flags)
{
	struct bpf_iter_task_kern *kit = (void *)it;

	BUILD_BUG_ON(sizeof(struct bpf_iter_task_kern) > sizeof(struct bpf_iter_task));
	BUILD_BUG_ON(__alignof__(struct bpf_iter_task_kern) !=
					__alignof__(struct bpf_iter_task));

	kit->pos = NULL;

	switch (flags) {
	case BPF_TASK_ITER_ALL_THREADS:
	case BPF_TASK_ITER_ALL_PROCS:
		break;
	case BPF_TASK_ITER_PROC_THREADS:
		if (!task__nullable)
			return -EINVAL;
		break;
	default:
		return -EINVAL;
	}

	if (flags == BPF_TASK_ITER_PROC_THREADS)
		kit->task = task__nullable;
	else
		kit->task = &init_task;
	kit->pos = kit->task;
	kit->flags = flags;
	return 0;
}

__bpf_kfunc struct task_struct *bpf_iter_task_next(struct bpf_iter_task *it)
{
	struct bpf_iter_task_kern *kit = (void *)it;
	struct task_struct *pos;
	unsigned int flags;

	flags = kit->flags;
	pos = kit->pos;

	if (!pos)
		return pos;

	if (flags == BPF_TASK_ITER_ALL_PROCS)
		goto get_next_task;

	kit->pos = __next_thread(kit->pos);
	if (kit->pos || flags == BPF_TASK_ITER_PROC_THREADS)
		return pos;

get_next_task:
	kit->task = next_task(kit->task);
	if (kit->task == &init_task)
		kit->pos = NULL;
	else
		kit->pos = kit->task;

	return pos;
}

__bpf_kfunc void bpf_iter_task_destroy(struct bpf_iter_task *it)
{
}

__bpf_kfunc_end_defs();
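
/* Hedged usage sketch for the task iterator kfuncs above (illustrative;
 * bpf_for_each() is libbpf's wrapper, and the surrounding
 * bpf_rcu_read_lock()/bpf_rcu_read_unlock() reflects that these kfuncs walk
 * the task lists under RCU):
 *
 *	struct task_struct *pos;
 *
 *	bpf_rcu_read_lock();
 *	bpf_for_each(task, pos, NULL, BPF_TASK_ITER_ALL_PROCS)
 *		proc_cnt++;
 *	bpf_rcu_read_unlock();
 *
 * With BPF_TASK_ITER_PROC_THREADS, a non-NULL task must be passed and only
 * that process's threads are visited.
 */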

DEFINE_PER_CPU(struct mmap_unlock_irq_work, mmap_unlock_work);

static void do_mmap_read_unlock(struct irq_work *entry)
{
	struct mmap_unlock_irq_work *work;

	if (WARN_ON_ONCE(IS_ENABLED(CONFIG_PREEMPT_RT)))
		return;

	work = container_of(entry, struct mmap_unlock_irq_work, irq_work);
	mmap_read_unlock_non_owner(work->mm);
}

static int __init task_iter_init(void)
{
	struct mmap_unlock_irq_work *work;
	int ret, cpu;

	for_each_possible_cpu(cpu) {
		work = per_cpu_ptr(&mmap_unlock_work, cpu);
		init_irq_work(&work->irq_work, do_mmap_read_unlock);
	}

	task_reg_info.ctx_arg_info[0].btf_id = btf_tracing_ids[BTF_TRACING_TYPE_TASK];
	ret = bpf_iter_reg_target(&task_reg_info);
	if (ret)
		return ret;

	task_file_reg_info.ctx_arg_info[0].btf_id = btf_tracing_ids[BTF_TRACING_TYPE_TASK];
	task_file_reg_info.ctx_arg_info[1].btf_id = btf_tracing_ids[BTF_TRACING_TYPE_FILE];
	ret = bpf_iter_reg_target(&task_file_reg_info);
	if (ret)
		return ret;

	task_vma_reg_info.ctx_arg_info[0].btf_id = btf_tracing_ids[BTF_TRACING_TYPE_TASK];
	task_vma_reg_info.ctx_arg_info[1].btf_id = btf_tracing_ids[BTF_TRACING_TYPE_VMA];
	return bpf_iter_reg_target(&task_vma_reg_info);
}
late_initcall(task_iter_init);