// SPDX-License-Identifier: GPL-2.0-only
// Copyright (C) 2019-2020 Arm Ltd.

#include <linux/compiler.h>
#include <linux/kasan-checks.h>
#include <linux/kernel.h>

#include <net/checksum.h>

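/*
 * 64-bit accumulation with end-around carry: if the addition wraps, the
 * lost carry is folded back into bit 0, which is exactly what a ones'
 * complement sum requires. For example, 0xffffffffffffffff + 0x2 wraps to
 * 0x1, and adding the carry back in gives 0x2, the correct folded result.
 */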
static u64 accumulate(u64 sum, u64 data)
{
	sum += data;
	if (sum < data)
		sum += 1;
	return sum;
}

/*
 * We over-read the buffer and this makes KASAN unhappy. Instead, disable
 * instrumentation and call kasan explicitly.
 */
unsigned int __no_sanitize_address do_csum(const unsigned char *buff, int len)
{
	unsigned int offset, shift, sum;
	const u64 *ptr;
	u64 data, sum64 = 0;

	/* Reject bogus lengths, including negative ones, up front */
	if (unlikely(len <= 0))
		return 0;

	offset = (unsigned long)buff & 7;
	/*
	 * This is to all intents and purposes safe, since rounding down cannot
	 * result in a different page or cache line being accessed, and @buff
	 * should absolutely not be pointing to anything read-sensitive. We do,
	 * however, have to be careful not to piss off KASAN, which means using
	 * unchecked reads to accommodate the head and tail, for which we'll
	 * compensate with an explicit check up-front.
	 */
	kasan_check_read(buff, len);
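	/*
	 * Round the pointer down to 8-byte alignment and rebias len so that,
	 * once the first aligned doubleword has been consumed, it counts the
	 * useful bytes still to come. It may well go negative here, in which
	 * case the tail masking below discards the excess from the single
	 * load we do make.
	 */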
	ptr = (u64 *)(buff - offset);
	len = len + offset - 8;

	/*
	 * Head: zero out any excess leading bytes. Shifting back by the same
	 * amount should be at least as fast as any other way of handling the
	 * odd/even alignment, and means we can ignore it until the very end.
	 */
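	/*
	 * On a little-endian kernel the bytes preceding @buff occupy the
	 * low-order end of the doubleword, hence shifting right then left:
	 * for example, offset == 3 gives shift == 24, clearing the three
	 * stray low bytes while bytes 3..7 survive untouched.
	 */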
	shift = offset * 8;
	data = *ptr++;
	data = (data >> shift) << shift;

	/*
	 * Body: straightforward aligned loads from here on (the paired loads
	 * underlying the quadword type still only need dword alignment). The
	 * main loop strictly excludes the tail, so the second loop will always
	 * run at least once.
	 */
	while (unlikely(len > 64)) {
		__uint128_t tmp1, tmp2, tmp3, tmp4;

		tmp1 = *(__uint128_t *)ptr;
		tmp2 = *(__uint128_t *)(ptr + 2);
		tmp3 = *(__uint128_t *)(ptr + 4);
		tmp4 = *(__uint128_t *)(ptr + 6);

		len -= 64;
		ptr += 8;

		/* This is the "don't dump the carry flag into a GPR" idiom */
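		/*
		 * Adding a 128-bit value to its own halves swapped sums the
		 * two 64-bit halves with the carry-out landing in bit 64, so
		 * the upper half ends up holding their ones' complement sum.
		 * The lines below then merge those upper halves pairwise in
		 * the same way, finally folding sum64 in as well, so each
		 * iteration reduces 64 bytes without the carry flag ever
		 * having to be materialised in a general-purpose register.
		 */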
		tmp1 += (tmp1 >> 64) | (tmp1 << 64);
		tmp2 += (tmp2 >> 64) | (tmp2 << 64);
		tmp3 += (tmp3 >> 64) | (tmp3 << 64);
		tmp4 += (tmp4 >> 64) | (tmp4 << 64);
		tmp1 = ((tmp1 >> 64) << 64) | (tmp2 >> 64);
		tmp1 += (tmp1 >> 64) | (tmp1 << 64);
		tmp3 = ((tmp3 >> 64) << 64) | (tmp4 >> 64);
		tmp3 += (tmp3 >> 64) | (tmp3 << 64);
		tmp1 = ((tmp1 >> 64) << 64) | (tmp3 >> 64);
		tmp1 += (tmp1 >> 64) | (tmp1 << 64);
		tmp1 = ((tmp1 >> 64) << 64) | sum64;
		tmp1 += (tmp1 >> 64) | (tmp1 << 64);
		sum64 = tmp1 >> 64;
	}
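	/*
	 * Mop up 16 bytes at a time, keeping one doubleword in flight in
	 * @data: the previous word is accumulated just before the next load,
	 * and whichever word was loaded last is left for the tail masking
	 * below rather than being summed here.
	 */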
	while (len > 8) {
		__uint128_t tmp;

		sum64 = accumulate(sum64, data);
		tmp = *(__uint128_t *)ptr;

		len -= 16;
		ptr += 2;

		data = tmp >> 64;
		sum64 = accumulate(sum64, tmp);
	}
	if (len > 0) {
		sum64 = accumulate(sum64, data);
		data = *ptr;
		len -= 8;
	}
	/*
	 * Tail: zero any over-read bytes similarly to the head, again
	 * preserving odd/even alignment.
	 */
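	/*
	 * At this point len is in [-7, 0], so len * -8 gives a shift of
	 * 0..56 that clears the top -len over-read bytes. For example,
	 * len == -3 gives shift == 24, wiping the three bytes read from
	 * beyond the end of the buffer out of the last doubleword.
	 */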
	shift = len * -8;
	data = (data << shift) >> shift;
	sum64 = accumulate(sum64, data);

	/* Finally, folding */
	sum64 += (sum64 >> 32) | (sum64 << 32);
	sum = sum64 >> 32;
	sum += (sum >> 16) | (sum << 16);
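	/*
	 * The folded 16-bit sum now sits in the top half of @sum. For an odd
	 * offset the byte lanes are swapped relative to the caller's view of
	 * the buffer, so swab32() swaps them back and the result is taken
	 * from the low half instead.
	 */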
	if (offset & 1)
		return (u16)swab32(sum);

	return sum >> 16;
}

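/*
 * IPv6 pseudo-header checksum: fold the source and destination addresses,
 * the upper-layer payload length and the next-header value into the running
 * sum supplied by the caller, then compress the result to 16 bits.
 */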
__sum16 csum_ipv6_magic(const struct in6_addr *saddr,
			const struct in6_addr *daddr,
			__u32 len, __u8 proto, __wsum csum)
{
	__uint128_t src, dst;
	u64 sum = (__force u64)csum;

	src = *(const __uint128_t *)saddr->s6_addr;
	dst = *(const __uint128_t *)daddr->s6_addr;

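	/*
	 * With little-endian loads, the length is already in network byte
	 * order after htonl(), and the next-header byte is shifted into the
	 * lane it would occupy when the wire-order word is read back. Each
	 * 128-bit address is then folded to 64 bits with the same
	 * rotate-and-add, end-around-carry trick used in do_csum(), leaving
	 * its sum in the upper half.
	 */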
	sum += (__force u32)htonl(len);
	sum += (u32)proto << 24;
	src += (src >> 64) | (src << 64);
	dst += (dst >> 64) | (dst << 64);

	sum = accumulate(sum, src >> 64);
	sum = accumulate(sum, dst >> 64);

	sum += ((sum >> 32) | (sum << 32));
	return csum_fold((__force __wsum)(sum >> 32));
}
EXPORT_SYMBOL(csum_ipv6_magic);