1// SPDX-License-Identifier: GPL-2.0
2
3#include <linux/ptrace.h>
4#include <stddef.h>
5#include <linux/bpf.h>
6#include <bpf/bpf_helpers.h>
7#include <bpf/bpf_tracing.h>
8#include "bpf_misc.h"
9
10char _license[] SEC("license") = "GPL";
11
/* Upper bound on the number of scatterlists walked per pass;
 * virtio scsi typically has at most 6 SGs.
 */
#define VIRTIO_MAX_SGS	6

/* Upper bound on the number of entries walked within one scatterlist.
 *
 * virtio blk typically has a max SEG of 128, but the verifier will fail
 * with SG_MAX = 128. The failure can be worked around with a smaller
 * SG_MAX, e.g. 10, which is what defining WORKAROUND selects.
 */
#define WORKAROUND
#if defined(WORKAROUND)
#define SG_MAX		10
#else
#define SG_MAX		128
#endif
25
/* Flag bits carried in the low bits of scatterlist::page_link. */
#define SG_CHAIN	0x01UL	/* page_link (flags masked off) points to another list */
#define SG_END		0x02UL	/* entry is the last one of its list */

/* Local definition of a scatterlist entry as used by this program.
 *
 * NOTE(review): only these three fields are declared here; confirm the
 * size and layout against the running kernel's struct scatterlist
 * before relying on the values read through it.
 */
struct scatterlist {
	unsigned long	page_link;	/* pointer bits plus SG_CHAIN/SG_END flags */
	unsigned int	offset, length;
};

/* Non-zero when the entry is a chain link rather than a data entry. */
#define sg_is_chain(sg)		((sg)->page_link & SG_CHAIN)

/* Non-zero when the entry terminates its list. */
#define sg_is_last(sg)		((sg)->page_link & SG_END)

/* Pointer stored in a chain entry, with both flag bits cleared. */
#define sg_chain_ptr(sg)						\
	((struct scatterlist *) ((sg)->page_link & ~(SG_CHAIN | SG_END)))
39
40static inline struct scatterlist *__sg_next(struct scatterlist *sgp)
41{
42	struct scatterlist sg;
43
44	bpf_probe_read_kernel(&sg, sizeof(sg), sgp);
45	if (sg_is_last(&sg))
46		return NULL;
47
48	sgp++;
49
50	bpf_probe_read_kernel(&sg, sizeof(sg), sgp);
51	if (sg_is_chain(&sg))
52		sgp = sg_chain_ptr(&sg);
53
54	return sgp;
55}
56
/* Fetch element @i of the kernel-resident pointer array @sgs.
 *
 * The pointer is copied out with bpf_probe_read_kernel(); the helper's
 * return value is not checked.
 */
static inline struct scatterlist *get_sgp(struct scatterlist **sgs, int i)
{
	struct scatterlist *head;

	bpf_probe_read_kernel(&head, sizeof(head), &sgs[i]);
	return head;
}
64
/* One-shot guard: the probe does its work only while this is 0 and sets
 * it to 1 once it has finished.
 */
int config = 0;
/* Written by the probe: total length of the second pass minus total
 * length of the first pass.
 */
int result = 0;
67
68SEC("kprobe/virtqueue_add_sgs")
69int BPF_KPROBE(trace_virtqueue_add_sgs, void *unused, struct scatterlist **sgs,
70	       unsigned int out_sgs, unsigned int in_sgs)
71{
72	struct scatterlist *sgp = NULL;
73	__u64 length1 = 0, length2 = 0;
74	unsigned int i, n, len;
75
76	if (config != 0)
77		return 0;
78
79	for (i = 0; (i < VIRTIO_MAX_SGS) && (i < out_sgs); i++) {
80		__sink(out_sgs);
81		for (n = 0, sgp = get_sgp(sgs, i); sgp && (n < SG_MAX);
82		     sgp = __sg_next(sgp)) {
83			bpf_probe_read_kernel(&len, sizeof(len), &sgp->length);
84			length1 += len;
85			n++;
86		}
87	}
88
89	for (i = 0; (i < VIRTIO_MAX_SGS) && (i < in_sgs); i++) {
90		__sink(in_sgs);
91		for (n = 0, sgp = get_sgp(sgs, i); sgp && (n < SG_MAX);
92		     sgp = __sg_next(sgp)) {
93			bpf_probe_read_kernel(&len, sizeof(len), &sgp->length);
94			length2 += len;
95			n++;
96		}
97	}
98
99	config = 1;
100	result = length2 - length1;
101	return 0;
102}
103