1/* SPDX-License-Identifier: GPL-2.0 */ 2#ifndef MEAN_AND_VARIANCE_H_ 3#define MEAN_AND_VARIANCE_H_ 4 5#include <linux/types.h> 6#include <linux/limits.h> 7#include <linux/math.h> 8#include <linux/math64.h> 9 10#define SQRT_U64_MAX 4294967295ULL 11 12/* 13 * u128_u: u128 user mode, because not all architectures support a real int128 14 * type 15 * 16 * We don't use this version in userspace, because in userspace we link with 17 * Rust and rustc has issues with u128. 18 */ 19 20#if defined(__SIZEOF_INT128__) && defined(__KERNEL__) && !defined(CONFIG_PARISC) 21 22typedef struct { 23 unsigned __int128 v; 24} __aligned(16) u128_u; 25 26static inline u128_u u64_to_u128(u64 a) 27{ 28 return (u128_u) { .v = a }; 29} 30 31static inline u64 u128_lo(u128_u a) 32{ 33 return a.v; 34} 35 36static inline u64 u128_hi(u128_u a) 37{ 38 return a.v >> 64; 39} 40 41static inline u128_u u128_add(u128_u a, u128_u b) 42{ 43 a.v += b.v; 44 return a; 45} 46 47static inline u128_u u128_sub(u128_u a, u128_u b) 48{ 49 a.v -= b.v; 50 return a; 51} 52 53static inline u128_u u128_shl(u128_u a, s8 shift) 54{ 55 a.v <<= shift; 56 return a; 57} 58 59static inline u128_u u128_square(u64 a) 60{ 61 u128_u b = u64_to_u128(a); 62 63 b.v *= b.v; 64 return b; 65} 66 67#else 68 69typedef struct { 70 u64 hi, lo; 71} __aligned(16) u128_u; 72 73/* conversions */ 74 75static inline u128_u u64_to_u128(u64 a) 76{ 77 return (u128_u) { .lo = a }; 78} 79 80static inline u64 u128_lo(u128_u a) 81{ 82 return a.lo; 83} 84 85static inline u64 u128_hi(u128_u a) 86{ 87 return a.hi; 88} 89 90/* arithmetic */ 91 92static inline u128_u u128_add(u128_u a, u128_u b) 93{ 94 u128_u c; 95 96 c.lo = a.lo + b.lo; 97 c.hi = a.hi + b.hi + (c.lo < a.lo); 98 return c; 99} 100 101static inline u128_u u128_sub(u128_u a, u128_u b) 102{ 103 u128_u c; 104 105 c.lo = a.lo - b.lo; 106 c.hi = a.hi - b.hi - (c.lo > a.lo); 107 return c; 108} 109 110static inline u128_u u128_shl(u128_u i, s8 shift) 111{ 112 u128_u r; 113 114 r.lo = i.lo << shift; 115 if (shift < 64) 116 r.hi = (i.hi << shift) | (i.lo >> (64 - shift)); 117 else { 118 r.hi = i.lo << (shift - 64); 119 r.lo = 0; 120 } 121 return r; 122} 123 124static inline u128_u u128_square(u64 i) 125{ 126 u128_u r; 127 u64 h = i >> 32, l = i & U32_MAX; 128 129 r = u128_shl(u64_to_u128(h*h), 64); 130 r = u128_add(r, u128_shl(u64_to_u128(h*l), 32)); 131 r = u128_add(r, u128_shl(u64_to_u128(l*h), 32)); 132 r = u128_add(r, u64_to_u128(l*l)); 133 return r; 134} 135 136#endif 137 138static inline u128_u u64s_to_u128(u64 hi, u64 lo) 139{ 140 u128_u c = u64_to_u128(hi); 141 142 c = u128_shl(c, 64); 143 c = u128_add(c, u64_to_u128(lo)); 144 return c; 145} 146 147u128_u u128_div(u128_u n, u64 d); 148 149struct mean_and_variance { 150 s64 n; 151 s64 sum; 152 u128_u sum_squares; 153}; 154 155/* expontentially weighted variant */ 156struct mean_and_variance_weighted { 157 s64 mean; 158 u64 variance; 159}; 160 161/** 162 * fast_divpow2() - fast approximation for n / (1 << d) 163 * @n: numerator 164 * @d: the power of 2 denominator. 165 * 166 * note: this rounds towards 0. 167 */ 168static inline s64 fast_divpow2(s64 n, u8 d) 169{ 170 return (n + ((n < 0) ? ((1 << d) - 1) : 0)) >> d; 171} 172 173/** 174 * mean_and_variance_update() - update a mean_and_variance struct @s1 with a new sample @v1 175 * and return it. 176 * @s1: the mean_and_variance to update. 177 * @v1: the new sample. 178 * 179 * see linked pdf equation 12. 180 */ 181static inline void 182mean_and_variance_update(struct mean_and_variance *s, s64 v) 183{ 184 s->n++; 185 s->sum += v; 186 s->sum_squares = u128_add(s->sum_squares, u128_square(abs(v))); 187} 188 189s64 mean_and_variance_get_mean(struct mean_and_variance s); 190u64 mean_and_variance_get_variance(struct mean_and_variance s1); 191u32 mean_and_variance_get_stddev(struct mean_and_variance s); 192 193void mean_and_variance_weighted_update(struct mean_and_variance_weighted *s, 194 s64 v, bool initted, u8 weight); 195 196s64 mean_and_variance_weighted_get_mean(struct mean_and_variance_weighted s, 197 u8 weight); 198u64 mean_and_variance_weighted_get_variance(struct mean_and_variance_weighted s, 199 u8 weight); 200u32 mean_and_variance_weighted_get_stddev(struct mean_and_variance_weighted s, 201 u8 weight); 202 203#endif // MEAN_AND_VAIRANCE_H_ 204