1// SPDX-License-Identifier: GPL-2.0
2
3#define _GNU_SOURCE
4
5#include <errno.h>
6#include <fcntl.h>
7#include <stdio.h>
8#include <stdlib.h>
9#include <string.h>
10#include <unistd.h>
11#include <net/if.h>
12#include <linux/if_tun.h>
13#include <linux/netlink.h>
14#include <linux/rtnetlink.h>
15#include <sys/ioctl.h>
16#include <sys/socket.h>
17#include <linux/virtio_net.h>
18#include <netinet/ip.h>
19#include <netinet/udp.h>
20#include "../kselftest_harness.h"
21
/* Names of the macvtap device under test and of the dummy device it links to. */
static const char param_dev_tap_name[] = "xmacvtap0";
static const char param_dev_dummy_name[] = "xdummy0";

/* MAC addresses written into generated Ethernet headers (source, destination). */
static unsigned char param_hwaddr_src[] = {
	0x00, 0xfe, 0x98, 0x14, 0x22, 0x42
};
static unsigned char param_hwaddr_dest[] = {
	0x00, 0xfe, 0x98, 0x94, 0xd2, 0x43
};

/* Bytes reserved for attributes after the fixed rtnetlink request headers. */
#define MAX_RTNL_PAYLOAD 2048
/* Byte value used to fill generated UDP payloads. */
#define PKT_DATA 0xCB
/* Packet buffer size: virtio-net header + Ethernet header + maximum MTU. */
#define TEST_PACKET_SZ \
	(sizeof(struct virtio_net_hdr) + ETH_HLEN + ETH_MAX_MTU)
32
/*
 * Append an attribute header to the netlink message @nh.
 *
 * The attribute is placed at the RTA-aligned end of the message, gets
 * @type and a length covering the header plus @len payload bytes, and
 * nh->nlmsg_len grows by the aligned attribute size. The payload itself
 * is not written here. No bounds check is done; the caller must provide
 * enough room behind @nh.
 *
 * Returns a pointer to the new attribute.
 */
static struct rtattr *rtattr_add(struct nlmsghdr *nh, unsigned short type,
				 unsigned short len)
{
	unsigned int tail = RTA_ALIGN(nh->nlmsg_len);
	struct rtattr *attr = (struct rtattr *)((uint8_t *)nh + tail);

	attr->rta_len = RTA_LENGTH(len);
	attr->rta_type = type;
	nh->nlmsg_len = tail + RTA_ALIGN(attr->rta_len);

	return attr;
}
43
/*
 * Start a container attribute of @type with an empty payload. Attributes
 * added afterwards land behind it; rtattr_end() later fixes up its length
 * so that it encloses them.
 */
static struct rtattr *rtattr_begin(struct nlmsghdr *nh, unsigned short type)
{
	struct rtattr *nest = rtattr_add(nh, type, 0);

	return nest;
}
48
/*
 * Close a container attribute: set @attr's length to the distance from
 * @attr to the current end of the message @nh, so that everything appended
 * since @attr was created becomes its payload.
 */
static void rtattr_end(struct nlmsghdr *nh, struct rtattr *attr)
{
	const uint8_t *start = (const uint8_t *)attr;
	const uint8_t *tail = (const uint8_t *)nh + nh->nlmsg_len;

	attr->rta_len = tail - start;
}
55
/*
 * Append attribute @type carrying the characters of @s WITHOUT the
 * terminating NUL (payload length is strlen(s)).
 *
 * Returns a pointer to the new attribute.
 */
static struct rtattr *rtattr_add_str(struct nlmsghdr *nh, unsigned short type,
				     const char *s)
{
	size_t n = strlen(s);
	struct rtattr *attr = rtattr_add(nh, type, n);

	memcpy(RTA_DATA(attr), s, n);

	return attr;
}
64
/*
 * Append attribute @type carrying @s INCLUDING its terminating NUL
 * (payload length is strlen(s) + 1).
 *
 * Returns a pointer to the new attribute.
 */
static struct rtattr *rtattr_add_strsz(struct nlmsghdr *nh, unsigned short type,
				       const char *s)
{
	size_t n = strlen(s) + 1;
	struct rtattr *attr = rtattr_add(nh, type, n);

	/* n already counts the terminator, so this copies it as well. */
	memcpy(RTA_DATA(attr), s, n);

	return attr;
}
73
/*
 * Append attribute @type whose payload is a copy of the @len bytes
 * starting at @arr.
 *
 * Returns a pointer to the new attribute.
 */
static struct rtattr *rtattr_add_any(struct nlmsghdr *nh, unsigned short type,
				     const void *arr, size_t len)
{
	struct rtattr *attr = rtattr_add(nh, type, len);
	void *payload = RTA_DATA(attr);

	memcpy(payload, arr, len);

	return attr;
}
82
83static int dev_create(const char *dev, const char *link_type,
84		      int (*fill_rtattr)(struct nlmsghdr *nh),
85		      int (*fill_info_data)(struct nlmsghdr *nh))
86{
87	struct {
88		struct nlmsghdr nh;
89		struct ifinfomsg info;
90		unsigned char data[MAX_RTNL_PAYLOAD];
91	} req;
92	struct rtattr *link_info, *info_data;
93	int ret, rtnl;
94
95	rtnl = socket(AF_NETLINK, SOCK_DGRAM, NETLINK_ROUTE);
96	if (rtnl < 0) {
97		fprintf(stderr, "%s: socket %s\n", __func__, strerror(errno));
98		return 1;
99	}
100
101	memset(&req, 0, sizeof(req));
102	req.nh.nlmsg_len = NLMSG_LENGTH(sizeof(req.info));
103	req.nh.nlmsg_flags = NLM_F_REQUEST | NLM_F_CREATE;
104	req.nh.nlmsg_type = RTM_NEWLINK;
105
106	req.info.ifi_family = AF_UNSPEC;
107	req.info.ifi_type = 1;
108	req.info.ifi_index = 0;
109	req.info.ifi_flags = IFF_BROADCAST | IFF_UP;
110	req.info.ifi_change = 0xffffffff;
111
112	rtattr_add_str(&req.nh, IFLA_IFNAME, dev);
113
114	if (fill_rtattr) {
115		ret = fill_rtattr(&req.nh);
116		if (ret)
117			return ret;
118	}
119
120	link_info = rtattr_begin(&req.nh, IFLA_LINKINFO);
121
122	rtattr_add_strsz(&req.nh, IFLA_INFO_KIND, link_type);
123
124	if (fill_info_data) {
125		info_data = rtattr_begin(&req.nh, IFLA_INFO_DATA);
126		ret = fill_info_data(&req.nh);
127		if (ret)
128			return ret;
129		rtattr_end(&req.nh, info_data);
130	}
131
132	rtattr_end(&req.nh, link_info);
133
134	ret = send(rtnl, &req, req.nh.nlmsg_len, 0);
135	if (ret < 0)
136		fprintf(stderr, "%s: send %s\n", __func__, strerror(errno));
137	ret = (unsigned int)ret != req.nh.nlmsg_len;
138
139	close(rtnl);
140	return ret;
141}
142
143static int dev_delete(const char *dev)
144{
145	struct {
146		struct nlmsghdr nh;
147		struct ifinfomsg info;
148		unsigned char data[MAX_RTNL_PAYLOAD];
149	} req;
150	int ret, rtnl;
151
152	rtnl = socket(AF_NETLINK, SOCK_DGRAM, NETLINK_ROUTE);
153	if (rtnl < 0) {
154		fprintf(stderr, "%s: socket %s\n", __func__, strerror(errno));
155		return 1;
156	}
157
158	memset(&req, 0, sizeof(req));
159	req.nh.nlmsg_len = NLMSG_LENGTH(sizeof(req.info));
160	req.nh.nlmsg_flags = NLM_F_REQUEST;
161	req.nh.nlmsg_type = RTM_DELLINK;
162
163	req.info.ifi_family = AF_UNSPEC;
164
165	rtattr_add_str(&req.nh, IFLA_IFNAME, dev);
166
167	ret = send(rtnl, &req, req.nh.nlmsg_len, 0);
168	if (ret < 0)
169		fprintf(stderr, "%s: send %s\n", __func__, strerror(errno));
170
171	ret = (unsigned int)ret != req.nh.nlmsg_len;
172
173	close(rtnl);
174	return ret;
175}
176
177static int macvtap_fill_rtattr(struct nlmsghdr *nh)
178{
179	int ifindex;
180
181	ifindex = if_nametoindex(param_dev_dummy_name);
182	if (ifindex == 0) {
183		fprintf(stderr, "%s: ifindex  %s\n", __func__, strerror(errno));
184		return -errno;
185	}
186
187	rtattr_add_any(nh, IFLA_LINK, &ifindex, sizeof(ifindex));
188	rtattr_add_any(nh, IFLA_ADDRESS, param_hwaddr_src, ETH_ALEN);
189
190	return 0;
191}
192
193static int opentap(const char *devname)
194{
195	int ifindex;
196	char buf[256];
197	int fd;
198	struct ifreq ifr;
199
200	ifindex = if_nametoindex(devname);
201	if (ifindex == 0) {
202		fprintf(stderr, "%s: ifindex %s\n", __func__, strerror(errno));
203		return -errno;
204	}
205
206	sprintf(buf, "/dev/tap%d", ifindex);
207	fd = open(buf, O_RDWR | O_NONBLOCK);
208	if (fd < 0) {
209		fprintf(stderr, "%s: open %s\n", __func__, strerror(errno));
210		return -errno;
211	}
212
213	memset(&ifr, 0, sizeof(ifr));
214	strcpy(ifr.ifr_name, devname);
215	ifr.ifr_flags = IFF_TAP | IFF_NO_PI | IFF_VNET_HDR | IFF_MULTI_QUEUE;
216	if (ioctl(fd, TUNSETIFF, &ifr, sizeof(ifr)) < 0)
217		return -errno;
218	return fd;
219}
220
221size_t build_eth(uint8_t *buf, uint16_t proto)
222{
223	struct ethhdr *eth = (struct ethhdr *)buf;
224
225	eth->h_proto = htons(proto);
226	memcpy(eth->h_source, param_hwaddr_src, ETH_ALEN);
227	memcpy(eth->h_dest, param_hwaddr_dest, ETH_ALEN);
228
229	return ETH_HLEN;
230}
231
/*
 * Sum @len bytes at @buf as native-endian 16-bit words, without folding.
 * An odd trailing byte is added as-is. A @len of zero or less yields 0.
 *
 * Words are fetched with memcpy() so that @buf needs no particular
 * alignment and is never accessed through an incompatible pointer type.
 *
 * Returns the 32-bit running sum, to be folded by finish_ip_csum().
 */
static uint32_t add_csum(const uint8_t *buf, int len)
{
	uint32_t sum = 0;
	uint16_t word;

	while (len > 1) {
		memcpy(&word, buf, sizeof(word));
		sum += word;
		buf += sizeof(word);
		len -= 2;
	}

	/* Only a genuine trailing byte counts; a negative len reads nothing. */
	if (len > 0)
		sum += *buf;

	return sum;
}
247
/*
 * Fold the 32-bit running sum @sum into 16 bits with end-around carry and
 * return its one's complement.
 *
 * Folding repeats until no carry is left: adding the two halves once can
 * itself overflow 16 bits (e.g. 0xffff + 0x0001), and that carry must be
 * added back rather than truncated away.
 */
static uint16_t finish_ip_csum(uint32_t sum)
{
	while (sum >> 16)
		sum = (sum & 0xffff) + (sum >> 16);

	return (uint16_t)~sum;
}
256
/*
 * Checksum @len bytes at @buf on top of the initial partial sum @sum:
 * the bytes are accumulated with add_csum() and the total is folded and
 * complemented by finish_ip_csum().
 */
static uint16_t build_ip_csum(const uint8_t *buf, int len,
			      uint32_t sum)
{
	return finish_ip_csum(sum + add_csum(buf, len));
}
263
/*
 * Write an IPv4 header for a UDP datagram with @payload_len payload bytes
 * at @buf: version 4, 20-byte header (ihl 5), TTL 8, id 1337, protocol UDP,
 * source 172.17.0.2, destination 172.17.0.1, and a header checksum.
 *
 * Fields not assigned here (e.g. tos, frag_off) keep whatever @buf already
 * holds, so callers are expected to pass a zeroed buffer.
 *
 * Returns the header length in bytes.
 */
static int build_ipv4_header(uint8_t *buf, int payload_len)
{
	struct iphdr *iph = (struct iphdr *)buf;

	iph->ihl = 5;
	iph->version = 4;
	iph->ttl = 8;
	iph->tot_len =
		htons(sizeof(*iph) + sizeof(struct udphdr) + payload_len);
	iph->id = htons(1337);
	iph->protocol = IPPROTO_UDP;
	/* Unsigned constants: 172 << 24 does not fit in a signed int. */
	iph->saddr = htonl((172U << 24) | (17U << 16) | 2U);
	iph->daddr = htonl((172U << 24) | (17U << 16) | 1U);
	/* The checksum field itself must read as zero while summing. */
	iph->check = 0;
	iph->check = build_ip_csum(buf, iph->ihl << 2, 0);

	return iph->ihl << 2;
}
281
282static int build_udp_packet(uint8_t *buf, int payload_len, bool csum_off)
283{
284	const int ip4alen = sizeof(uint32_t);
285	struct udphdr *udph = (struct udphdr *)buf;
286	int len = sizeof(*udph) + payload_len;
287	uint32_t sum = 0;
288
289	udph->source = htons(22);
290	udph->dest = htons(58822);
291	udph->len = htons(len);
292
293	memset(buf + sizeof(struct udphdr), PKT_DATA, payload_len);
294
295	sum = add_csum(buf - 2 * ip4alen, 2 * ip4alen);
296	sum += htons(IPPROTO_UDP) + udph->len;
297
298	if (!csum_off)
299		sum += add_csum(buf, len);
300
301	udph->check = finish_ip_csum(sum);
302
303	return sizeof(*udph) + payload_len;
304}
305
306size_t build_test_packet_valid_udp_gso(uint8_t *buf, size_t payload_len)
307{
308	uint8_t *cur = buf;
309	struct virtio_net_hdr *vh = (struct virtio_net_hdr *)buf;
310
311	vh->hdr_len = ETH_HLEN + sizeof(struct iphdr) + sizeof(struct udphdr);
312	vh->flags = VIRTIO_NET_HDR_F_NEEDS_CSUM;
313	vh->csum_start = ETH_HLEN + sizeof(struct iphdr);
314	vh->csum_offset = __builtin_offsetof(struct udphdr, check);
315	vh->gso_type = VIRTIO_NET_HDR_GSO_UDP;
316	vh->gso_size = ETH_DATA_LEN - sizeof(struct iphdr);
317	cur += sizeof(*vh);
318
319	cur += build_eth(cur, ETH_P_IP);
320	cur += build_ipv4_header(cur, payload_len);
321	cur += build_udp_packet(cur, payload_len, true);
322
323	return cur - buf;
324}
325
326size_t build_test_packet_valid_udp_csum(uint8_t *buf, size_t payload_len)
327{
328	uint8_t *cur = buf;
329	struct virtio_net_hdr *vh = (struct virtio_net_hdr *)buf;
330
331	vh->flags = VIRTIO_NET_HDR_F_DATA_VALID;
332	vh->gso_type = VIRTIO_NET_HDR_GSO_NONE;
333	cur += sizeof(*vh);
334
335	cur += build_eth(cur, ETH_P_IP);
336	cur += build_ipv4_header(cur, payload_len);
337	cur += build_udp_packet(cur, payload_len, false);
338
339	return cur - buf;
340}
341
342size_t build_test_packet_crash_tap_invalid_eth_proto(uint8_t *buf,
343						     size_t payload_len)
344{
345	uint8_t *cur = buf;
346	struct virtio_net_hdr *vh = (struct virtio_net_hdr *)buf;
347
348	vh->hdr_len = ETH_HLEN + sizeof(struct iphdr) + sizeof(struct udphdr);
349	vh->flags = 0;
350	vh->gso_type = VIRTIO_NET_HDR_GSO_UDP;
351	vh->gso_size = ETH_DATA_LEN - sizeof(struct iphdr);
352	cur += sizeof(*vh);
353
354	cur += build_eth(cur, 0);
355	cur += sizeof(struct iphdr) + sizeof(struct udphdr);
356	cur += build_ipv4_header(cur, payload_len);
357	cur += build_udp_packet(cur, payload_len, true);
358	cur += payload_len;
359
360	return cur - buf;
361}
362
/*
 * Per-test state of the "tap" fixture.
 * fd: value returned by opentap() in FIXTURE_SETUP; released in
 *     FIXTURE_TEARDOWN.
 */
FIXTURE(tap)
{
	int fd;
};
367
/*
 * Create the dummy device, create the macvtap device on top of it
 * (macvtap_fill_rtattr links it to the dummy device), then open the
 * macvtap's tap character device and keep the descriptor in self->fd.
 */
FIXTURE_SETUP(tap)
{
	int ret;

	ret = dev_create(param_dev_dummy_name, "dummy", NULL, NULL);
	EXPECT_EQ(ret, 0);

	ret = dev_create(param_dev_tap_name, "macvtap", macvtap_fill_rtattr,
			 NULL);
	EXPECT_EQ(ret, 0);

	/* opentap() returns a negative errno on failure. */
	self->fd = opentap(param_dev_tap_name);
	ASSERT_GE(self->fd, 0);
}
382
383FIXTURE_TEARDOWN(tap)
384{
385	int ret;
386
387	if (self->fd != -1)
388		close(self->fd);
389
390	ret = dev_delete(param_dev_tap_name);
391	EXPECT_EQ(ret, 0);
392
393	ret = dev_delete(param_dev_dummy_name);
394	EXPECT_EQ(ret, 0);
395}
396
397TEST_F(tap, test_packet_valid_udp_gso)
398{
399	uint8_t pkt[TEST_PACKET_SZ];
400	size_t off;
401	int ret;
402
403	memset(pkt, 0, sizeof(pkt));
404	off = build_test_packet_valid_udp_gso(pkt, 1021);
405	ret = write(self->fd, pkt, off);
406	ASSERT_EQ(ret, off);
407}
408
409TEST_F(tap, test_packet_valid_udp_csum)
410{
411	uint8_t pkt[TEST_PACKET_SZ];
412	size_t off;
413	int ret;
414
415	memset(pkt, 0, sizeof(pkt));
416	off = build_test_packet_valid_udp_csum(pkt, 1024);
417	ret = write(self->fd, pkt, off);
418	ASSERT_EQ(ret, off);
419}
420
421TEST_F(tap, test_packet_crash_tap_invalid_eth_proto)
422{
423	uint8_t pkt[TEST_PACKET_SZ];
424	size_t off;
425	int ret;
426
427	memset(pkt, 0, sizeof(pkt));
428	off = build_test_packet_crash_tap_invalid_eth_proto(pkt, 1024);
429	ret = write(self->fd, pkt, off);
430	ASSERT_EQ(ret, -1);
431	ASSERT_EQ(errno, EINVAL);
432}
433
/*
 * kselftest harness macro from kselftest_harness.h; presumably supplies
 * main() and runs the tests registered above (confirm against the header).
 */
TEST_HARNESS_MAIN
435