1// SPDX-License-Identifier: GPL-2.0
2#include <linux/kernel.h>
3#include <linux/netfilter.h>
4#include <linux/netfilter_ipv4.h>
5#include <linux/netfilter_ipv6.h>
6#include <net/netfilter/nf_queue.h>
7#include <net/ip6_checksum.h>
8
9#ifdef CONFIG_INET
10__sum16 nf_ip_checksum(struct sk_buff *skb, unsigned int hook,
11		       unsigned int dataoff, u8 protocol)
12{
13	const struct iphdr *iph = ip_hdr(skb);
14	__sum16 csum = 0;
15
16	switch (skb->ip_summed) {
17	case CHECKSUM_COMPLETE:
18		if (hook != NF_INET_PRE_ROUTING && hook != NF_INET_LOCAL_IN)
19			break;
20		if ((protocol != IPPROTO_TCP && protocol != IPPROTO_UDP &&
21		    !csum_fold(skb->csum)) ||
22		    !csum_tcpudp_magic(iph->saddr, iph->daddr,
23				       skb->len - dataoff, protocol,
24				       skb->csum)) {
25			skb->ip_summed = CHECKSUM_UNNECESSARY;
26			break;
27		}
28		fallthrough;
29	case CHECKSUM_NONE:
30		if (protocol != IPPROTO_TCP && protocol != IPPROTO_UDP)
31			skb->csum = 0;
32		else
33			skb->csum = csum_tcpudp_nofold(iph->saddr, iph->daddr,
34						       skb->len - dataoff,
35						       protocol, 0);
36		csum = __skb_checksum_complete(skb);
37	}
38	return csum;
39}
40EXPORT_SYMBOL(nf_ip_checksum);
41#endif
42
43static __sum16 nf_ip_checksum_partial(struct sk_buff *skb, unsigned int hook,
44				      unsigned int dataoff, unsigned int len,
45				      u8 protocol)
46{
47	const struct iphdr *iph = ip_hdr(skb);
48	__sum16 csum = 0;
49
50	switch (skb->ip_summed) {
51	case CHECKSUM_COMPLETE:
52		if (len == skb->len - dataoff)
53			return nf_ip_checksum(skb, hook, dataoff, protocol);
54		fallthrough;
55	case CHECKSUM_NONE:
56		skb->csum = csum_tcpudp_nofold(iph->saddr, iph->daddr, protocol,
57					       skb->len - dataoff, 0);
58		skb->ip_summed = CHECKSUM_NONE;
59		return __skb_checksum_complete_head(skb, dataoff + len);
60	}
61	return csum;
62}
63
64__sum16 nf_ip6_checksum(struct sk_buff *skb, unsigned int hook,
65			unsigned int dataoff, u8 protocol)
66{
67	const struct ipv6hdr *ip6h = ipv6_hdr(skb);
68	__sum16 csum = 0;
69
70	switch (skb->ip_summed) {
71	case CHECKSUM_COMPLETE:
72		if (hook != NF_INET_PRE_ROUTING && hook != NF_INET_LOCAL_IN)
73			break;
74		if (!csum_ipv6_magic(&ip6h->saddr, &ip6h->daddr,
75				     skb->len - dataoff, protocol,
76				     csum_sub(skb->csum,
77					      skb_checksum(skb, 0,
78							   dataoff, 0)))) {
79			skb->ip_summed = CHECKSUM_UNNECESSARY;
80			break;
81		}
82		fallthrough;
83	case CHECKSUM_NONE:
84		skb->csum = ~csum_unfold(
85				csum_ipv6_magic(&ip6h->saddr, &ip6h->daddr,
86					     skb->len - dataoff,
87					     protocol,
88					     csum_sub(0,
89						      skb_checksum(skb, 0,
90								   dataoff, 0))));
91		csum = __skb_checksum_complete(skb);
92	}
93	return csum;
94}
95EXPORT_SYMBOL(nf_ip6_checksum);
96
97static __sum16 nf_ip6_checksum_partial(struct sk_buff *skb, unsigned int hook,
98				       unsigned int dataoff, unsigned int len,
99				       u8 protocol)
100{
101	const struct ipv6hdr *ip6h = ipv6_hdr(skb);
102	__wsum hsum;
103	__sum16 csum = 0;
104
105	switch (skb->ip_summed) {
106	case CHECKSUM_COMPLETE:
107		if (len == skb->len - dataoff)
108			return nf_ip6_checksum(skb, hook, dataoff, protocol);
109		fallthrough;
110	case CHECKSUM_NONE:
111		hsum = skb_checksum(skb, 0, dataoff, 0);
112		skb->csum = ~csum_unfold(csum_ipv6_magic(&ip6h->saddr,
113							 &ip6h->daddr,
114							 skb->len - dataoff,
115							 protocol,
116							 csum_sub(0, hsum)));
117		skb->ip_summed = CHECKSUM_NONE;
118		return __skb_checksum_complete_head(skb, dataoff + len);
119	}
120	return csum;
121};
122
123__sum16 nf_checksum(struct sk_buff *skb, unsigned int hook,
124		    unsigned int dataoff, u8 protocol,
125		    unsigned short family)
126{
127	__sum16 csum = 0;
128
129	switch (family) {
130	case AF_INET:
131		csum = nf_ip_checksum(skb, hook, dataoff, protocol);
132		break;
133	case AF_INET6:
134		csum = nf_ip6_checksum(skb, hook, dataoff, protocol);
135		break;
136	}
137
138	return csum;
139}
140EXPORT_SYMBOL_GPL(nf_checksum);
141
142__sum16 nf_checksum_partial(struct sk_buff *skb, unsigned int hook,
143			    unsigned int dataoff, unsigned int len,
144			    u8 protocol, unsigned short family)
145{
146	__sum16 csum = 0;
147
148	switch (family) {
149	case AF_INET:
150		csum = nf_ip_checksum_partial(skb, hook, dataoff, len,
151					      protocol);
152		break;
153	case AF_INET6:
154		csum = nf_ip6_checksum_partial(skb, hook, dataoff, len,
155					       protocol);
156		break;
157	}
158
159	return csum;
160}
161EXPORT_SYMBOL_GPL(nf_checksum_partial);
162
163int nf_route(struct net *net, struct dst_entry **dst, struct flowi *fl,
164	     bool strict, unsigned short family)
165{
166	const struct nf_ipv6_ops *v6ops __maybe_unused;
167	int ret = 0;
168
169	switch (family) {
170	case AF_INET:
171		ret = nf_ip_route(net, dst, fl, strict);
172		break;
173	case AF_INET6:
174		ret = nf_ip6_route(net, dst, fl, strict);
175		break;
176	}
177
178	return ret;
179}
180EXPORT_SYMBOL_GPL(nf_route);
181
182/* Only get and check the lengths, not do any hop-by-hop stuff. */
183int nf_ip6_check_hbh_len(struct sk_buff *skb, u32 *plen)
184{
185	int len, off = sizeof(struct ipv6hdr);
186	unsigned char *nh;
187
188	if (!pskb_may_pull(skb, off + 8))
189		return -ENOMEM;
190	nh = (unsigned char *)(ipv6_hdr(skb) + 1);
191	len = (nh[1] + 1) << 3;
192
193	if (!pskb_may_pull(skb, off + len))
194		return -ENOMEM;
195	nh = skb_network_header(skb);
196
197	off += 2;
198	len -= 2;
199	while (len > 0) {
200		int optlen;
201
202		if (nh[off] == IPV6_TLV_PAD1) {
203			off++;
204			len--;
205			continue;
206		}
207		if (len < 2)
208			return -EBADMSG;
209		optlen = nh[off + 1] + 2;
210		if (optlen > len)
211			return -EBADMSG;
212
213		if (nh[off] == IPV6_TLV_JUMBO) {
214			u32 pkt_len;
215
216			if (nh[off + 1] != 4 || (off & 3) != 2)
217				return -EBADMSG;
218			pkt_len = ntohl(*(__be32 *)(nh + off + 2));
219			if (pkt_len <= IPV6_MAXPLEN ||
220			    ipv6_hdr(skb)->payload_len)
221				return -EBADMSG;
222			if (pkt_len > skb->len - sizeof(struct ipv6hdr))
223				return -EBADMSG;
224			*plen = pkt_len;
225		}
226		off += optlen;
227		len -= optlen;
228	}
229
230	return len ? -EBADMSG : 0;
231}
232EXPORT_SYMBOL_GPL(nf_ip6_check_hbh_len);
233