1// SPDX-License-Identifier: GPL-2.0
2
3#define _GNU_SOURCE
4
5#include <arpa/inet.h>
6#include <errno.h>
7#include <error.h>
8#include <fcntl.h>
9#include <limits.h>
10#include <linux/filter.h>
11#include <linux/bpf.h>
12#include <linux/if_packet.h>
13#include <linux/if_vlan.h>
14#include <linux/virtio_net.h>
15#include <net/if.h>
16#include <net/ethernet.h>
17#include <netinet/ip.h>
18#include <netinet/udp.h>
19#include <poll.h>
20#include <sched.h>
21#include <stdbool.h>
22#include <stdint.h>
23#include <stdio.h>
24#include <stdlib.h>
25#include <string.h>
26#include <sys/mman.h>
27#include <sys/socket.h>
28#include <sys/stat.h>
29#include <sys/types.h>
30#include <unistd.h>
31
32#include "psock_lib.h"
33
/* Test configuration; the cfg_use_* flags and lengths are set by parse_opts(). */
static bool	cfg_use_bind;		/* -b: bind() the tx socket and send with write() */
static bool	cfg_use_csum_off;	/* -c: request UDP checksum offload in the vnet header */
static bool	cfg_use_csum_off_bad;	/* -C: place the checksum field one byte past the packet end */
static bool	cfg_use_dgram;		/* -d: SOCK_DGRAM tx socket; the Ethernet header is not sent */
static bool	cfg_use_gso;		/* -g: request UDP GSO in the vnet header */
static bool	cfg_use_qdisc_bypass;	/* -q: set PACKET_QDISC_BYPASS on the tx socket */
static bool	cfg_use_vlan;		/* -V: insert an 802.1Q tag after the Ethernet header */
static bool	cfg_use_vnet;		/* -v: set PACKET_VNET_HDR and send the virtio_net_hdr */

static char	*cfg_ifname = "lo";		/* device used for bind()/sendto() */
static int	cfg_mtu	= 1500;			/* basis for -g gso_size; main() sets lo to mtu 1500 */
static int	cfg_payload_len = DATA_LEN;	/* -l: UDP payload bytes (DATA_LEN from psock_lib.h) */
static int	cfg_truncate_len = INT_MAX;	/* -t: cap on bytes of tbuf passed to do_send() */
static uint16_t	cfg_port = 8000;		/* UDP destination port; the rx socket binds to it */

/* test sending up to max mtu + 1 */
#define TEST_SZ	(sizeof(struct virtio_net_hdr) + ETH_HLEN + ETH_MAX_MTU + 1)

/*
 * tx and rx buffers. Being static they start zeroed; the header builders
 * only write some fields (e.g. vnet flags via |=), so the others stay zero.
 */
static char tbuf[TEST_SZ], rbuf[TEST_SZ];
53
/*
 * Sum @num_u16 consecutive 16-bit words starting at @start.
 * Carries are not folded here; build_ip_csum() folds the result.
 * A non-positive @num_u16 yields 0.
 */
static unsigned long add_csum_hword(const uint16_t *start, int num_u16)
{
	const uint16_t *word = start;
	unsigned long total = 0;
	int remaining = num_u16;

	while (remaining > 0) {
		total += *word++;
		remaining--;
	}

	return total;
}
64
/*
 * Internet checksum helper: add the @num_u16 16-bit words at @start to
 * the running sum @sum, fold all carries back into 16 bits and return
 * the one's complement of the folded value.
 */
static uint16_t build_ip_csum(const uint16_t *start, int num_u16,
			      unsigned long sum)
{
	unsigned long acc = sum + add_csum_hword(start, num_u16);

	/* end-around carry: repeat until nothing is left above bit 15 */
	do {
		acc = (acc >> 16) + (acc & 0xffff);
	} while (acc > 0xffff);

	return (uint16_t)~acc;
}
75
76static int build_vnet_header(void *header)
77{
78	struct virtio_net_hdr *vh = header;
79
80	vh->hdr_len = ETH_HLEN + sizeof(struct iphdr) + sizeof(struct udphdr);
81
82	if (cfg_use_csum_off) {
83		vh->flags |= VIRTIO_NET_HDR_F_NEEDS_CSUM;
84		vh->csum_start = ETH_HLEN + sizeof(struct iphdr);
85		vh->csum_offset = __builtin_offsetof(struct udphdr, check);
86
87		/* position check field exactly one byte beyond end of packet */
88		if (cfg_use_csum_off_bad)
89			vh->csum_start += sizeof(struct udphdr) + cfg_payload_len -
90					  vh->csum_offset - 1;
91	}
92
93	if (cfg_use_gso) {
94		vh->gso_type = VIRTIO_NET_HDR_GSO_UDP;
95		vh->gso_size = cfg_mtu - sizeof(struct iphdr);
96	}
97
98	return sizeof(*vh);
99}
100
101static int build_eth_header(void *header)
102{
103	struct ethhdr *eth = header;
104
105	if (cfg_use_vlan) {
106		uint16_t *tag = header + ETH_HLEN;
107
108		eth->h_proto = htons(ETH_P_8021Q);
109		tag[1] = htons(ETH_P_IP);
110		return ETH_HLEN + 4;
111	}
112
113	eth->h_proto = htons(ETH_P_IP);
114	return ETH_HLEN;
115}
116
/*
 * Write a 20-byte IPv4 header (no options) at @header for a UDP datagram
 * carrying @payload_len bytes of payload, from 172.17.0.2 to 172.17.0.1,
 * and fill in its header checksum.
 *
 * Returns the header length in bytes.
 */
static int build_ipv4_header(void *header, int payload_len)
{
	struct iphdr *iph = header;

	iph->ihl = 5;
	iph->version = 4;
	iph->ttl = 8;
	iph->tot_len = htons(sizeof(*iph) + sizeof(struct udphdr) + payload_len);
	iph->id = htons(1337);
	iph->protocol = IPPROTO_UDP;
	/* unsigned operands: 172 << 24 does not fit in a signed int (UB) */
	iph->saddr = htonl((172U << 24) | (17U << 16) | 2U);
	iph->daddr = htonl((172U << 24) | (17U << 16) | 1U);
	/* the header checksum is defined over the header with check == 0 */
	iph->check = 0;
	iph->check = build_ip_csum((void *) iph, iph->ihl << 1, 0);

	return iph->ihl << 2;
}
133
134static int build_udp_header(void *header, int payload_len)
135{
136	const int alen = sizeof(uint32_t);
137	struct udphdr *udph = header;
138	int len = sizeof(*udph) + payload_len;
139
140	udph->source = htons(9);
141	udph->dest = htons(cfg_port);
142	udph->len = htons(len);
143
144	if (cfg_use_csum_off)
145		udph->check = build_ip_csum(header - (2 * alen), alen,
146					    htons(IPPROTO_UDP) + udph->len);
147	else
148		udph->check = 0;
149
150	return sizeof(*udph);
151}
152
153static int build_packet(int payload_len)
154{
155	int off = 0;
156
157	off += build_vnet_header(tbuf);
158	off += build_eth_header(tbuf + off);
159	off += build_ipv4_header(tbuf + off, payload_len);
160	off += build_udp_header(tbuf + off, payload_len);
161
162	if (off + payload_len > sizeof(tbuf))
163		error(1, 0, "payload length exceeds max");
164
165	memset(tbuf + off, DATA_CHAR, payload_len);
166
167	return off + payload_len;
168}
169
170static void do_bind(int fd)
171{
172	struct sockaddr_ll laddr = {0};
173
174	laddr.sll_family = AF_PACKET;
175	laddr.sll_protocol = htons(ETH_P_IP);
176	laddr.sll_ifindex = if_nametoindex(cfg_ifname);
177	if (!laddr.sll_ifindex)
178		error(1, errno, "if_nametoindex");
179
180	if (bind(fd, (void *)&laddr, sizeof(laddr)))
181		error(1, errno, "bind");
182}
183
184static void do_send(int fd, char *buf, int len)
185{
186	int ret;
187
188	if (!cfg_use_vnet) {
189		buf += sizeof(struct virtio_net_hdr);
190		len -= sizeof(struct virtio_net_hdr);
191	}
192	if (cfg_use_dgram) {
193		buf += ETH_HLEN;
194		len -= ETH_HLEN;
195	}
196
197	if (cfg_use_bind) {
198		ret = write(fd, buf, len);
199	} else {
200		struct sockaddr_ll laddr = {0};
201
202		laddr.sll_protocol = htons(ETH_P_IP);
203		laddr.sll_ifindex = if_nametoindex(cfg_ifname);
204		if (!laddr.sll_ifindex)
205			error(1, errno, "if_nametoindex");
206
207		ret = sendto(fd, buf, len, 0, (void *)&laddr, sizeof(laddr));
208	}
209
210	if (ret == -1)
211		error(1, errno, "write");
212	if (ret != len)
213		error(1, 0, "write: %u %u", ret, len);
214
215	fprintf(stderr, "tx: %u\n", ret);
216}
217
218static int do_tx(void)
219{
220	const int one = 1;
221	int fd, len;
222
223	fd = socket(PF_PACKET, cfg_use_dgram ? SOCK_DGRAM : SOCK_RAW, 0);
224	if (fd == -1)
225		error(1, errno, "socket t");
226
227	if (cfg_use_bind)
228		do_bind(fd);
229
230	if (cfg_use_qdisc_bypass &&
231	    setsockopt(fd, SOL_PACKET, PACKET_QDISC_BYPASS, &one, sizeof(one)))
232		error(1, errno, "setsockopt qdisc bypass");
233
234	if (cfg_use_vnet &&
235	    setsockopt(fd, SOL_PACKET, PACKET_VNET_HDR, &one, sizeof(one)))
236		error(1, errno, "setsockopt vnet");
237
238	len = build_packet(cfg_payload_len);
239
240	if (cfg_truncate_len < len)
241		len = cfg_truncate_len;
242
243	do_send(fd, tbuf, len);
244
245	if (close(fd))
246		error(1, errno, "close t");
247
248	return len;
249}
250
251static int setup_rx(void)
252{
253	struct timeval tv = { .tv_usec = 100 * 1000 };
254	struct sockaddr_in raddr = {0};
255	int fd;
256
257	fd = socket(PF_INET, SOCK_DGRAM, 0);
258	if (fd == -1)
259		error(1, errno, "socket r");
260
261	if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)))
262		error(1, errno, "setsockopt rcv timeout");
263
264	raddr.sin_family = AF_INET;
265	raddr.sin_port = htons(cfg_port);
266	raddr.sin_addr.s_addr = htonl(INADDR_ANY);
267
268	if (bind(fd, (void *)&raddr, sizeof(raddr)))
269		error(1, errno, "bind r");
270
271	return fd;
272}
273
274static void do_rx(int fd, int expected_len, char *expected)
275{
276	int ret;
277
278	ret = recv(fd, rbuf, sizeof(rbuf), 0);
279	if (ret == -1)
280		error(1, errno, "recv");
281	if (ret != expected_len)
282		error(1, 0, "recv: %u != %u", ret, expected_len);
283
284	if (memcmp(rbuf, expected, ret))
285		error(1, 0, "recv: data mismatch");
286
287	fprintf(stderr, "rx: %u\n", ret);
288}
289
/*
 * Open a raw packet socket that observes frames on cfg_ifname. The
 * socket has a 100 ms receive timeout and the filter from
 * pair_udp_setfilter(); presumably that filter matches only the test
 * datagram (see psock_lib.h). Returns the socket fd.
 */
static int setup_sniffer(void)
{
	const struct timeval timeout = { .tv_sec = 0, .tv_usec = 100 * 1000 };
	int sock = socket(PF_PACKET, SOCK_RAW, 0);

	if (sock == -1)
		error(1, errno, "socket p");

	if (setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout)))
		error(1, errno, "setsockopt rcv timeout");

	pair_udp_setfilter(sock);
	do_bind(sock);

	return sock;
}
307
308static void parse_opts(int argc, char **argv)
309{
310	int c;
311
312	while ((c = getopt(argc, argv, "bcCdgl:qt:vV")) != -1) {
313		switch (c) {
314		case 'b':
315			cfg_use_bind = true;
316			break;
317		case 'c':
318			cfg_use_csum_off = true;
319			break;
320		case 'C':
321			cfg_use_csum_off_bad = true;
322			break;
323		case 'd':
324			cfg_use_dgram = true;
325			break;
326		case 'g':
327			cfg_use_gso = true;
328			break;
329		case 'l':
330			cfg_payload_len = strtoul(optarg, NULL, 0);
331			break;
332		case 'q':
333			cfg_use_qdisc_bypass = true;
334			break;
335		case 't':
336			cfg_truncate_len = strtoul(optarg, NULL, 0);
337			break;
338		case 'v':
339			cfg_use_vnet = true;
340			break;
341		case 'V':
342			cfg_use_vlan = true;
343			break;
344		default:
345			error(1, 0, "%s: parse error", argv[0]);
346		}
347	}
348
349	if (cfg_use_vlan && cfg_use_dgram)
350		error(1, 0, "option vlan (-V) conflicts with dgram (-d)");
351
352	if (cfg_use_csum_off && !cfg_use_vnet)
353		error(1, 0, "option csum offload (-c) requires vnet (-v)");
354
355	if (cfg_use_csum_off_bad && !cfg_use_csum_off)
356		error(1, 0, "option csum bad (-C) requires csum offload (-c)");
357
358	if (cfg_use_gso && !cfg_use_csum_off)
359		error(1, 0, "option gso (-g) requires csum offload (-c)");
360}
361
362static void run_test(void)
363{
364	int fdr, fds, total_len;
365
366	fdr = setup_rx();
367	fds = setup_sniffer();
368
369	total_len = do_tx();
370
371	/* BPF filter accepts only this length, vlan changes MAC */
372	if (cfg_payload_len == DATA_LEN && !cfg_use_vlan)
373		do_rx(fds, total_len - sizeof(struct virtio_net_hdr),
374		      tbuf + sizeof(struct virtio_net_hdr));
375
376	do_rx(fdr, cfg_payload_len, tbuf + total_len - cfg_payload_len);
377
378	if (close(fds))
379		error(1, errno, "close s");
380	if (close(fdr))
381		error(1, errno, "close r");
382}
383
/*
 * Run one setup command through system(). errno is only meaningful when
 * system() itself returns -1; a command that ran but failed is reported
 * with its raw wait status instead of a stale errno.
 */
static void run_setup_cmd(const char *cmd, const char *what)
{
	int status = system(cmd);

	if (status == -1)
		error(1, errno, "%s", what);
	if (status != 0)
		error(1, 0, "%s: command failed (status %d)", what, status);
}

/* Configure lo for the test, then send and verify one packet. */
int main(int argc, char **argv)
{
	parse_opts(argc, argv);

	run_setup_cmd("ip link set dev lo mtu 1500", "ip link set mtu");
	run_setup_cmd("ip addr add dev lo 172.17.0.1/24", "ip addr add");
	run_setup_cmd("sysctl -w net.ipv4.conf.lo.accept_local=1",
		      "sysctl lo.accept_local");

	run_test();

	fprintf(stderr, "OK\n\n");
	return 0;
}
400