1/* SPDX-License-Identifier: GPL-2.0 */
2#ifndef MEAN_AND_VARIANCE_H_
3#define MEAN_AND_VARIANCE_H_
4
5#include <linux/types.h>
6#include <linux/limits.h>
7#include <linux/math.h>
8#include <linux/math64.h>
9
/* floor(sqrt(U64_MAX)) == 2^32 - 1: the largest u64 whose square still fits in a u64 */
#define SQRT_U64_MAX 0xffffffffULL
11
/*
 * u128_u: a 128-bit unsigned integer type that also works in user mode,
 * because not all architectures support a real int128 type.
 *
 * The native __int128 version below is only used when building the kernel:
 * we don't use it in userspace, because in userspace we link with Rust and
 * rustc has issues with u128.
 */
19
20#if defined(__SIZEOF_INT128__) && defined(__KERNEL__) && !defined(CONFIG_PARISC)
21
/* Native representation: one compiler-provided 128-bit unsigned integer. */
typedef struct {
	__uint128_t v;
} __aligned(16) u128_u;
25
26static inline u128_u u64_to_u128(u64 a)
27{
28	return (u128_u) { .v = a };
29}
30
31static inline u64 u128_lo(u128_u a)
32{
33	return a.v;
34}
35
36static inline u64 u128_hi(u128_u a)
37{
38	return a.v >> 64;
39}
40
41static inline u128_u u128_add(u128_u a, u128_u b)
42{
43	a.v += b.v;
44	return a;
45}
46
47static inline u128_u u128_sub(u128_u a, u128_u b)
48{
49	a.v -= b.v;
50	return a;
51}
52
53static inline u128_u u128_shl(u128_u a, s8 shift)
54{
55	a.v <<= shift;
56	return a;
57}
58
59static inline u128_u u128_square(u64 a)
60{
61	u128_u b = u64_to_u128(a);
62
63	b.v *= b.v;
64	return b;
65}
66
67#else
68
69typedef struct {
70	u64 hi, lo;
71} __aligned(16) u128_u;
72
73/* conversions */
74
75static inline u128_u u64_to_u128(u64 a)
76{
77	return (u128_u) { .lo = a };
78}
79
80static inline u64 u128_lo(u128_u a)
81{
82	return a.lo;
83}
84
85static inline u64 u128_hi(u128_u a)
86{
87	return a.hi;
88}
89
90/* arithmetic */
91
92static inline u128_u u128_add(u128_u a, u128_u b)
93{
94	u128_u c;
95
96	c.lo = a.lo + b.lo;
97	c.hi = a.hi + b.hi + (c.lo < a.lo);
98	return c;
99}
100
101static inline u128_u u128_sub(u128_u a, u128_u b)
102{
103	u128_u c;
104
105	c.lo = a.lo - b.lo;
106	c.hi = a.hi - b.hi - (c.lo > a.lo);
107	return c;
108}
109
110static inline u128_u u128_shl(u128_u i, s8 shift)
111{
112	u128_u r;
113
114	r.lo = i.lo << shift;
115	if (shift < 64)
116		r.hi = (i.hi << shift) | (i.lo >> (64 - shift));
117	else {
118		r.hi = i.lo << (shift - 64);
119		r.lo = 0;
120	}
121	return r;
122}
123
124static inline u128_u u128_square(u64 i)
125{
126	u128_u r;
127	u64  h = i >> 32, l = i & U32_MAX;
128
129	r =             u128_shl(u64_to_u128(h*h), 64);
130	r = u128_add(r, u128_shl(u64_to_u128(h*l), 32));
131	r = u128_add(r, u128_shl(u64_to_u128(l*h), 32));
132	r = u128_add(r,          u64_to_u128(l*l));
133	return r;
134}
135
136#endif
137
138static inline u128_u u64s_to_u128(u64 hi, u64 lo)
139{
140	u128_u c = u64_to_u128(hi);
141
142	c = u128_shl(c, 64);
143	c = u128_add(c, u64_to_u128(lo));
144	return c;
145}
146
/*
 * u128_div() - divide the 128-bit value @n by the 64-bit value @d.
 * Defined out of line. NOTE(review): presumably truncating division
 * (n / d); confirm the rounding and the d == 0 behaviour against the
 * implementation.
 */
u128_u u128_div(u128_u n, u64 d);
148
/*
 * Plain (unweighted) accumulator for a stream of s64 samples, filled by
 * mean_and_variance_update(). The getters take it by value, so the mean and
 * variance are derived from these three running totals.
 */
struct mean_and_variance {
	s64	n;		/* number of samples added */
	s64	sum;		/* sum of all samples */
	u128_u	sum_squares;	/* sum of squared sample magnitudes; 128-bit so a squared s64 always fits */
};
154
/*
 * Exponentially weighted variant. The weight is not stored here; callers
 * pass it to every mean_and_variance_weighted_*() call.
 * NOTE(review): the fields presumably hold the current weighted estimates
 * (possibly in a scaled fixed-point form); confirm against the implementation.
 */
struct mean_and_variance_weighted {
	s64	mean;
	u64	variance;
};
160
161/**
162 * fast_divpow2() - fast approximation for n / (1 << d)
163 * @n: numerator
164 * @d: the power of 2 denominator.
165 *
166 * note: this rounds towards 0.
167 */
168static inline s64 fast_divpow2(s64 n, u8 d)
169{
170	return (n + ((n < 0) ? ((1 << d) - 1) : 0)) >> d;
171}
172
173/**
174 * mean_and_variance_update() - update a mean_and_variance struct @s1 with a new sample @v1
175 * and return it.
176 * @s1: the mean_and_variance to update.
177 * @v1: the new sample.
178 *
179 * see linked pdf equation 12.
180 */
181static inline void
182mean_and_variance_update(struct mean_and_variance *s, s64 v)
183{
184	s->n++;
185	s->sum += v;
186	s->sum_squares = u128_add(s->sum_squares, u128_square(abs(v)));
187}
188
/*
 * Getters for the plain accumulator. They take the struct by value, so they
 * can only use n, sum and sum_squares.
 * NOTE(review): presumably mean = sum / n, variance = sum_squares / n - mean^2
 * and stddev is its integer square root; confirm against the implementation.
 */
s64 mean_and_variance_get_mean(struct mean_and_variance s);
u64 mean_and_variance_get_variance(struct mean_and_variance s1);
u32 mean_and_variance_get_stddev(struct mean_and_variance s);

/*
 * Exponentially weighted variant. NOTE(review): @initted presumably tells the
 * update whether @s already holds a sample, so the first sample can seed it.
 * @weight presumably controls the decay, and the same @weight should be passed
 * to the getters. Confirm against the implementation.
 */
void mean_and_variance_weighted_update(struct mean_and_variance_weighted *s,
		s64 v, bool initted, u8 weight);

s64 mean_and_variance_weighted_get_mean(struct mean_and_variance_weighted s,
		u8 weight);
u64 mean_and_variance_weighted_get_variance(struct mean_and_variance_weighted s,
		u8 weight);
u32 mean_and_variance_weighted_get_stddev(struct mean_and_variance_weighted s,
		u8 weight);
202
#endif // MEAN_AND_VARIANCE_H_
204