1// SPDX-License-Identifier: GPL-2.0
2/*
3 * ipsec.c - Check xfrm on veth inside a net-ns.
4 * Copyright (c) 2018 Dmitry Safonov
5 */
6
7#define _GNU_SOURCE
8
9#include <arpa/inet.h>
10#include <asm/types.h>
11#include <errno.h>
12#include <fcntl.h>
13#include <limits.h>
14#include <linux/limits.h>
15#include <linux/netlink.h>
16#include <linux/random.h>
17#include <linux/rtnetlink.h>
18#include <linux/veth.h>
19#include <linux/xfrm.h>
20#include <netinet/in.h>
21#include <net/if.h>
22#include <sched.h>
23#include <stdbool.h>
24#include <stdint.h>
25#include <stdio.h>
26#include <stdlib.h>
27#include <string.h>
28#include <sys/mman.h>
29#include <sys/socket.h>
30#include <sys/stat.h>
31#include <sys/syscall.h>
32#include <sys/types.h>
33#include <sys/wait.h>
34#include <time.h>
35#include <unistd.h>
36
37#include "../kselftest.h"
38
/* Log helper: every message is prefixed with the caller's pid and source line. */
#define printk(fmt, ...)						\
	ksft_print_msg("%d[%u] " fmt "\n", getpid(), __LINE__, ##__VA_ARGS__)

/* Like printk(), but appends ": %m" so the current errno text is printed. */
#define pr_err(fmt, ...)	printk(fmt ": %m", ##__VA_ARGS__)

/* Compile-time check: the array size becomes negative when condition holds. */
#define BUILD_BUG_ON(condition) ((void)sizeof(char[1 - 2*!!(condition)]))

#define IPV4_STR_SZ	16	/* xxx.xxx.xxx.xxx is longest + \0 */
#define MAX_PAYLOAD	2048
#define XFRM_ALGO_KEY_BUF_SIZE	512
#define MAX_PROCESSES	(1 << 14) /* /16 mask divided by /30 subnets */
#define INADDR_A	((in_addr_t) 0x0a000000) /* 10.0.0.0 */
#define INADDR_B	((in_addr_t) 0xc0a80000) /* 192.168.0.0 */

/* /30 mask for one veth connection */
#define PREFIX_LEN	30
/*
 * Host parts of the two usable addresses inside the nr-th /30 subnet.
 * The argument is parenthesized so that expressions such as
 * child_ip(a + b) expand as intended.
 */
#define child_ip(nr)	(4 * (nr) + 1)
#define grchild_ip(nr)	(4 * (nr) + 2)

#define VETH_FMT	"ktst-%d"
#define VETH_LEN	12

/* Number of entries in xfrm_key_entries[] */
#define XFRM_ALGO_NR_KEYS 29
62
/* Network namespace descriptors; -1 means "not opened". */
static int nsfd_parent = -1, nsfd_childa = -1, nsfd_childb = -1;
/* Used by do_ping() as the length of every ping payload. */
static long page_size;

/*
 * The kselftest counters (ksft_cnt) are static, hence every child has
 * its own private copy.  Children therefore report each test outcome
 * through this pipe and the parent does the counting.
 */
static int results_fd[2];

/* Pause between two consecutive pings, in nanoseconds (50 ms). */
const unsigned int ping_delay_nsec	= 50 * 1000 * 1000;
/* Receive timeout handed to udp_ping_init(), stored in tv_usec. */
const unsigned int ping_timeout		= 300;
/* Number of pings attempted per do_ping() run. */
const unsigned int ping_count		= 100;
/* Minimal number of successful pings for a run to pass. */
const unsigned int ping_success		= 80;
79
/*
 * One row of the algorithm -> key length table consulted by xfrm_fill_key().
 * NOTE(review): key_len looks like a length in bits (e.g. 128 for
 * "hmac(md5)") - confirm against the xfrm uapi.
 */
struct xfrm_key_entry {
	char algo_name[35];
	int key_len;
};

struct xfrm_key_entry xfrm_key_entries[] = {
	/* null transforms */
	{ .algo_name = "digest_null",			.key_len = 0 },
	{ .algo_name = "ecb(cipher_null)",		.key_len = 0 },
	/* authentication and block ciphers, ordered by key length */
	{ .algo_name = "cbc(des)",			.key_len = 64 },
	{ .algo_name = "hmac(md5)",			.key_len = 128 },
	{ .algo_name = "cmac(aes)",			.key_len = 128 },
	{ .algo_name = "xcbc(aes)",			.key_len = 128 },
	{ .algo_name = "cbc(cast5)",			.key_len = 128 },
	{ .algo_name = "cbc(serpent)",			.key_len = 128 },
	{ .algo_name = "hmac(sha1)",			.key_len = 160 },
	{ .algo_name = "hmac(rmd160)",			.key_len = 160 },
	{ .algo_name = "cbc(des3_ede)",			.key_len = 192 },
	{ .algo_name = "hmac(sha256)",			.key_len = 256 },
	{ .algo_name = "cbc(aes)",			.key_len = 256 },
	{ .algo_name = "cbc(camellia)",			.key_len = 256 },
	{ .algo_name = "cbc(twofish)",			.key_len = 256 },
	{ .algo_name = "rfc3686(ctr(aes))",		.key_len = 288 },
	{ .algo_name = "hmac(sha384)",			.key_len = 384 },
	{ .algo_name = "cbc(blowfish)",			.key_len = 448 },
	{ .algo_name = "hmac(sha512)",			.key_len = 512 },
	/* AEAD entries, the suffix names the variant */
	{ .algo_name = "rfc4106(gcm(aes))-128",		.key_len = 160 },
	{ .algo_name = "rfc4543(gcm(aes))-128",		.key_len = 160 },
	{ .algo_name = "rfc4309(ccm(aes))-128",		.key_len = 152 },
	{ .algo_name = "rfc4106(gcm(aes))-192",		.key_len = 224 },
	{ .algo_name = "rfc4543(gcm(aes))-192",		.key_len = 224 },
	{ .algo_name = "rfc4309(ccm(aes))-192",		.key_len = 216 },
	{ .algo_name = "rfc4106(gcm(aes))-256",		.key_len = 288 },
	{ .algo_name = "rfc4543(gcm(aes))-256",		.key_len = 288 },
	{ .algo_name = "rfc4309(ccm(aes))-256",		.key_len = 280 },
	{ .algo_name = "rfc7539(chacha20,poly1305)-128", .key_len = 0 },
};
116
/*
 * Fill @buf with @buflen pseudo-random bytes taken from rand().
 *
 * One rand() value is consumed per sizeof(int) bytes; a trailing partial
 * chunk receives the leading bytes of one more rand() value.  The data is
 * copied with memcpy(), so @buf may have any alignment and any effective
 * type, and no arithmetic is done on a void pointer.
 * A zero @buflen leaves @buf untouched and does not call rand().
 */
static void randomize_buffer(void *buf, size_t buflen)
{
	unsigned char *dst = buf;
	size_t off = 0;

	while (off < buflen) {
		size_t chunk = buflen - off;
		int word = rand();

		if (chunk > sizeof(word))
			chunk = sizeof(word);

		memcpy(dst + off, &word, chunk);
		off += chunk;
	}
}
137
/*
 * Move the calling process into a new network namespace and open a
 * descriptor that refers to it.
 *
 * Returns the descriptor (> 0) on success, -1 on failure; a descriptor
 * value of 0 is reported as a failure as well.
 */
static int unshare_open(void)
{
	static const char self_netns[] = "/proc/self/ns/net";
	int nsfd;

	if (unshare(CLONE_NEWNET)) {
		pr_err("unshare()");
		return -1;
	}

	nsfd = open(self_netns, O_RDONLY);
	if (nsfd > 0)
		return nsfd;

	pr_err("open(%s)", self_netns);
	return -1;
}
156
/*
 * Enter the network namespace referred to by @fd.
 * Returns 0 on success, -1 (after logging errno) when setns() fails.
 */
static int switch_ns(int fd)
{
	int rc = setns(fd, CLONE_NEWNET);

	if (!rc)
		return 0;

	pr_err("setns()");
	return -1;
}
165
166/*
167 * Running the test inside a new parent net namespace to bother less
168 * about cleanup on error-path.
169 */
170static int init_namespaces(void)
171{
172	nsfd_parent = unshare_open();
173	if (nsfd_parent <= 0)
174		return -1;
175
176	nsfd_childa = unshare_open();
177	if (nsfd_childa <= 0)
178		return -1;
179
180	if (switch_ns(nsfd_parent))
181		return -1;
182
183	nsfd_childb = unshare_open();
184	if (nsfd_childb <= 0)
185		return -1;
186
187	if (switch_ns(nsfd_parent))
188		return -1;
189	return 0;
190}
191
/*
 * Make sure *@sock is an open netlink socket for protocol @proto.
 *
 * @sock:   in/out descriptor; a value > 0 is treated as "already open".
 * @seq_nr: in/out sequence number.  It is advanced by one when the socket
 *          is reused and seeded with a random value when a new socket is
 *          created.
 * @proto:  netlink protocol passed to socket().
 *
 * Returns 0 on success, -1 if socket() failed (a descriptor value of 0 is
 * treated as a failure too).
 */
static int netlink_sock(int *sock, uint32_t *seq_nr, int proto)
{
	if (*sock > 0) {
		/* bump the caller's counter, not the local pointer */
		(*seq_nr)++;
		return 0;
	}

	*sock = socket(AF_NETLINK, SOCK_RAW | SOCK_CLOEXEC, proto);
	if (*sock <= 0) {
		pr_err("socket(AF_NETLINK)");
		return -1;
	}

	randomize_buffer(seq_nr, sizeof(*seq_nr));

	return 0;
}
209
/*
 * Return the position right behind the current end of message @nh, i.e.
 * nlmsg_len rounded up to the attribute alignment - where the next
 * attribute is going to be placed.
 */
static inline struct rtattr *rtattr_hdr(struct nlmsghdr *nh)
{
	char *msg_start = (char *)nh;
	size_t aligned_len = RTA_ALIGN(nh->nlmsg_len);

	return (struct rtattr *)(msg_start + aligned_len);
}
214
/*
 * Append one attribute to the netlink message @nh.
 *
 * @nh:       message header; nlmsg_len is grown to cover the new attribute.
 * @req_sz:   total size of the buffer that holds the message.
 * @rta_type: attribute type.
 * @payload:  attribute data; may be NULL for an attribute that carries no
 *            data of its own (see rtattr_begin()), in which case nothing
 *            is copied.
 * @size:     number of payload bytes.
 *
 * Returns 0 on success, -1 if the attribute does not fit into @req_sz.
 */
static int rtattr_pack(struct nlmsghdr *nh, size_t req_sz,
		unsigned short rta_type, const void *payload, size_t size)
{
	/* NLMSG_ALIGNTO == RTA_ALIGNTO, nlmsg_len already aligned */
	struct rtattr *attr = rtattr_hdr(nh);
	size_t nl_size = RTA_ALIGN(nh->nlmsg_len) + RTA_LENGTH(size);

	if (req_sz < nl_size) {
		printk("req buf is too small: %zu < %zu", req_sz, nl_size);
		return -1;
	}
	nh->nlmsg_len = nl_size;

	attr->rta_len = RTA_LENGTH(size);
	attr->rta_type = rta_type;
	/* memcpy() with a NULL source is undefined even when size is 0 */
	if (payload)
		memcpy(RTA_DATA(attr), payload, size);

	return 0;
}
234
/*
 * Start an attribute that further attributes will be nested in: pack
 * @payload as its initial content and hand back the attribute's address so
 * the caller can later close it with rtattr_end().
 *
 * Returns a null pointer when the attribute did not fit into the buffer.
 */
static struct rtattr *_rtattr_begin(struct nlmsghdr *nh, size_t req_sz,
		unsigned short rta_type, const void *payload, size_t size)
{
	/* must be sampled before rtattr_pack() moves the end of the message */
	struct rtattr *start = rtattr_hdr(nh);
	int err = rtattr_pack(nh, req_sz, rta_type, payload, size);

	return err ? NULL : start;
}
245
/*
 * Start a nested attribute that has no payload of its own.
 * Returns the attribute's address, or a null pointer if it did not fit.
 */
static inline struct rtattr *rtattr_begin(struct nlmsghdr *nh, size_t req_sz,
		unsigned short rta_type)
{
	const void *no_payload = NULL;

	return _rtattr_begin(nh, req_sz, rta_type, no_payload, 0);
}
251
/*
 * Close a nested attribute opened with rtattr_begin(): its length is set
 * to span from @attr up to the current end of message @nh.
 */
static inline void rtattr_end(struct nlmsghdr *nh, struct rtattr *attr)
{
	char *first = (char *)attr;
	char *last = (char *)nh + nh->nlmsg_len;

	attr->rta_len = last - first;
}
258
/*
 * Append the VETH_INFO_PEER attribute describing the second end of a veth
 * pair: an ifinfomsg header followed by the peer's name and the descriptor
 * of the namespace it is created in.
 *
 * Returns 0 on success, -1 if the request buffer is too small.
 */
static int veth_pack_peerb(struct nlmsghdr *nh, size_t req_sz,
		const char *peer, int ns)
{
	struct rtattr *nest;
	struct ifinfomsg peer_hdr;

	memset(&peer_hdr, 0, sizeof(peer_hdr));
	peer_hdr.ifi_change	= 0xFFFFFFFF;
	peer_hdr.ifi_family	= AF_UNSPEC;

	nest = _rtattr_begin(nh, req_sz, VETH_INFO_PEER,
			&peer_hdr, sizeof(peer_hdr));
	if (!nest)
		return -1;

	if (rtattr_pack(nh, req_sz, IFLA_IFNAME, peer, strlen(peer)) ||
	    rtattr_pack(nh, req_sz, IFLA_NET_NS_FD, &ns, sizeof(ns)))
		return -1;

	rtattr_end(nh, nest);
	return 0;
}
283
/*
 * Receive the reply to a request sent with NLM_F_ACK and check it.
 *
 * Returns 0 when the reply is an NLMSG_ERROR message with a zero error
 * code, the non-zero error code carried by the message otherwise, and -1
 * when recv() fails or a different message type arrives.
 */
static int netlink_check_answer(int sock)
{
	struct nlmsgerror {
		struct nlmsghdr hdr;
		int error;
		struct nlmsghdr orig_msg;
	} answer;
	ssize_t got = recv(sock, &answer, sizeof(answer), 0);

	if (got < 0) {
		pr_err("recv()");
		return -1;
	}

	if (answer.hdr.nlmsg_type != NLMSG_ERROR) {
		printk("expected NLMSG_ERROR, got %d", (int)answer.hdr.nlmsg_type);
		return -1;
	}

	if (!answer.error)
		return 0;

	printk("NLMSG_ERROR: %d: %s",
		answer.error, strerror(-answer.error));
	return answer.error;
}
306
307static int veth_add(int sock, uint32_t seq, const char *peera, int ns_a,
308		const char *peerb, int ns_b)
309{
310	uint16_t flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE;
311	struct {
312		struct nlmsghdr		nh;
313		struct ifinfomsg	info;
314		char			attrbuf[MAX_PAYLOAD];
315	} req;
316	const char veth_type[] = "veth";
317	struct rtattr *link_info, *info_data;
318
319	memset(&req, 0, sizeof(req));
320	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));
321	req.nh.nlmsg_type	= RTM_NEWLINK;
322	req.nh.nlmsg_flags	= flags;
323	req.nh.nlmsg_seq	= seq;
324	req.info.ifi_family	= AF_UNSPEC;
325	req.info.ifi_change	= 0xFFFFFFFF;
326
327	if (rtattr_pack(&req.nh, sizeof(req), IFLA_IFNAME, peera, strlen(peera)))
328		return -1;
329
330	if (rtattr_pack(&req.nh, sizeof(req), IFLA_NET_NS_FD, &ns_a, sizeof(ns_a)))
331		return -1;
332
333	link_info = rtattr_begin(&req.nh, sizeof(req), IFLA_LINKINFO);
334	if (!link_info)
335		return -1;
336
337	if (rtattr_pack(&req.nh, sizeof(req), IFLA_INFO_KIND, veth_type, sizeof(veth_type)))
338		return -1;
339
340	info_data = rtattr_begin(&req.nh, sizeof(req), IFLA_INFO_DATA);
341	if (!info_data)
342		return -1;
343
344	if (veth_pack_peerb(&req.nh, sizeof(req), peerb, ns_b))
345		return -1;
346
347	rtattr_end(&req.nh, info_data);
348	rtattr_end(&req.nh, link_info);
349
350	if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
351		pr_err("send()");
352		return -1;
353	}
354	return netlink_check_answer(sock);
355}
356
357static int ip4_addr_set(int sock, uint32_t seq, const char *intf,
358		struct in_addr addr, uint8_t prefix)
359{
360	uint16_t flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE;
361	struct {
362		struct nlmsghdr		nh;
363		struct ifaddrmsg	info;
364		char			attrbuf[MAX_PAYLOAD];
365	} req;
366
367	memset(&req, 0, sizeof(req));
368	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));
369	req.nh.nlmsg_type	= RTM_NEWADDR;
370	req.nh.nlmsg_flags	= flags;
371	req.nh.nlmsg_seq	= seq;
372	req.info.ifa_family	= AF_INET;
373	req.info.ifa_prefixlen	= prefix;
374	req.info.ifa_index	= if_nametoindex(intf);
375
376#ifdef DEBUG
377	{
378		char addr_str[IPV4_STR_SZ] = {};
379
380		strncpy(addr_str, inet_ntoa(addr), IPV4_STR_SZ - 1);
381
382		printk("ip addr set %s", addr_str);
383	}
384#endif
385
386	if (rtattr_pack(&req.nh, sizeof(req), IFA_LOCAL, &addr, sizeof(addr)))
387		return -1;
388
389	if (rtattr_pack(&req.nh, sizeof(req), IFA_ADDRESS, &addr, sizeof(addr)))
390		return -1;
391
392	if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
393		pr_err("send()");
394		return -1;
395	}
396	return netlink_check_answer(sock);
397}
398
399static int link_set_up(int sock, uint32_t seq, const char *intf)
400{
401	struct {
402		struct nlmsghdr		nh;
403		struct ifinfomsg	info;
404		char			attrbuf[MAX_PAYLOAD];
405	} req;
406
407	memset(&req, 0, sizeof(req));
408	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));
409	req.nh.nlmsg_type	= RTM_NEWLINK;
410	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
411	req.nh.nlmsg_seq	= seq;
412	req.info.ifi_family	= AF_UNSPEC;
413	req.info.ifi_change	= 0xFFFFFFFF;
414	req.info.ifi_index	= if_nametoindex(intf);
415	req.info.ifi_flags	= IFF_UP;
416	req.info.ifi_change	= IFF_UP;
417
418	if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
419		pr_err("send()");
420		return -1;
421	}
422	return netlink_check_answer(sock);
423}
424
425static int ip4_route_set(int sock, uint32_t seq, const char *intf,
426		struct in_addr src, struct in_addr dst)
427{
428	struct {
429		struct nlmsghdr	nh;
430		struct rtmsg	rt;
431		char		attrbuf[MAX_PAYLOAD];
432	} req;
433	unsigned int index = if_nametoindex(intf);
434
435	memset(&req, 0, sizeof(req));
436	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.rt));
437	req.nh.nlmsg_type	= RTM_NEWROUTE;
438	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE;
439	req.nh.nlmsg_seq	= seq;
440	req.rt.rtm_family	= AF_INET;
441	req.rt.rtm_dst_len	= 32;
442	req.rt.rtm_table	= RT_TABLE_MAIN;
443	req.rt.rtm_protocol	= RTPROT_BOOT;
444	req.rt.rtm_scope	= RT_SCOPE_LINK;
445	req.rt.rtm_type		= RTN_UNICAST;
446
447	if (rtattr_pack(&req.nh, sizeof(req), RTA_DST, &dst, sizeof(dst)))
448		return -1;
449
450	if (rtattr_pack(&req.nh, sizeof(req), RTA_PREFSRC, &src, sizeof(src)))
451		return -1;
452
453	if (rtattr_pack(&req.nh, sizeof(req), RTA_OIF, &index, sizeof(index)))
454		return -1;
455
456	if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
457		pr_err("send()");
458		return -1;
459	}
460
461	return netlink_check_answer(sock);
462}
463
464static int tunnel_set_route(int route_sock, uint32_t *route_seq, char *veth,
465		struct in_addr tunsrc, struct in_addr tundst)
466{
467	if (ip4_addr_set(route_sock, (*route_seq)++, "lo",
468			tunsrc, PREFIX_LEN)) {
469		printk("Failed to set ipv4 addr");
470		return -1;
471	}
472
473	if (ip4_route_set(route_sock, (*route_seq)++, veth, tunsrc, tundst)) {
474		printk("Failed to set ipv4 route");
475		return -1;
476	}
477
478	return 0;
479}
480
481static int init_child(int nsfd, char *veth, unsigned int src, unsigned int dst)
482{
483	struct in_addr intsrc = inet_makeaddr(INADDR_B, src);
484	struct in_addr tunsrc = inet_makeaddr(INADDR_A, src);
485	struct in_addr tundst = inet_makeaddr(INADDR_A, dst);
486	int route_sock = -1, ret = -1;
487	uint32_t route_seq;
488
489	if (switch_ns(nsfd))
490		return -1;
491
492	if (netlink_sock(&route_sock, &route_seq, NETLINK_ROUTE)) {
493		printk("Failed to open netlink route socket in child");
494		return -1;
495	}
496
497	if (ip4_addr_set(route_sock, route_seq++, veth, intsrc, PREFIX_LEN)) {
498		printk("Failed to set ipv4 addr");
499		goto err;
500	}
501
502	if (link_set_up(route_sock, route_seq++, veth)) {
503		printk("Failed to bring up %s", veth);
504		goto err;
505	}
506
507	if (tunnel_set_route(route_sock, &route_seq, veth, tunsrc, tundst)) {
508		printk("Failed to add tunnel route on %s", veth);
509		goto err;
510	}
511	ret = 0;
512
513err:
514	close(route_sock);
515	return ret;
516}
517
/* Size of every algorithm name buffer in struct xfrm_desc */
#define ALGO_LEN	64

/* Kind of test a descriptor stands for */
enum desc_type {
	CREATE_TUNNEL	= 0,
	ALLOCATE_SPI	= 1,
	MONITOR_ACQUIRE	= 2,
	EXPIRE_STATE	= 3,
	EXPIRE_POLICY	= 4,
	SPDINFO_ATTRS	= 5,
};

/* Printable labels, listed in the order of enum desc_type */
const char *desc_name[] = {
	[CREATE_TUNNEL]		= "create tunnel",
	[ALLOCATE_SPI]		= "alloc spi",
	[MONITOR_ACQUIRE]	= "monitor acquire",
	[EXPIRE_STATE]		= "expire state",
	[EXPIRE_POLICY]		= "expire policy",
	[SPDINFO_ATTRS]		= "spdinfo attributes",
	[SPDINFO_ATTRS + 1]	= ""
};

/*
 * Description of one test: its kind, the IPsec protocol and the
 * algorithm names handed to xfrm_state_pack_algo().
 */
struct xfrm_desc {
	enum desc_type	type;
	uint8_t		proto;			/* IPPROTO_AH/ESP/COMP */
	char		a_algo[ALGO_LEN];	/* authentication */
	char		e_algo[ALGO_LEN];	/* encryption */
	char		c_algo[ALGO_LEN];	/* compression */
	char		ae_algo[ALGO_LEN];	/* AEAD */
	unsigned int	icv_len;		/* copied to alg_icv_len for AEAD */
	/* unsigned key_len; */
};
546
547enum msg_type {
548	MSG_ACK		= 0,
549	MSG_EXIT,
550	MSG_PING,
551	MSG_XFRM_PREPARE,
552	MSG_XFRM_ADD,
553	MSG_XFRM_DEL,
554	MSG_XFRM_CLEANUP,
555};
556
557struct test_desc {
558	enum msg_type type;
559	union {
560		struct {
561			in_addr_t reply_ip;
562			unsigned int port;
563		} ping;
564		struct xfrm_desc xfrm_desc;
565	} body;
566};
567
568struct test_result {
569	struct xfrm_desc desc;
570	unsigned int res;
571};
572
573static void write_test_result(unsigned int res, struct xfrm_desc *d)
574{
575	struct test_result tr = {};
576	ssize_t ret;
577
578	tr.desc = *d;
579	tr.res = res;
580
581	ret = write(results_fd[1], &tr, sizeof(tr));
582	if (ret != sizeof(tr))
583		pr_err("Failed to write the result in pipe %zd", ret);
584}
585
586static void write_msg(int fd, struct test_desc *msg, bool exit_of_fail)
587{
588	ssize_t bytes = write(fd, msg, sizeof(*msg));
589
590	/* Make sure that write/read is atomic to a pipe */
591	BUILD_BUG_ON(sizeof(struct test_desc) > PIPE_BUF);
592
593	if (bytes < 0) {
594		pr_err("write()");
595		if (exit_of_fail)
596			exit(KSFT_FAIL);
597	}
598	if (bytes != sizeof(*msg)) {
599		pr_err("sent part of the message %zd/%zu", bytes, sizeof(*msg));
600		if (exit_of_fail)
601			exit(KSFT_FAIL);
602	}
603}
604
605static void read_msg(int fd, struct test_desc *msg, bool exit_of_fail)
606{
607	ssize_t bytes = read(fd, msg, sizeof(*msg));
608
609	if (bytes < 0) {
610		pr_err("read()");
611		if (exit_of_fail)
612			exit(KSFT_FAIL);
613	}
614	if (bytes != sizeof(*msg)) {
615		pr_err("got incomplete message %zd/%zu", bytes, sizeof(*msg));
616		if (exit_of_fail)
617			exit(KSFT_FAIL);
618	}
619}
620
/*
 * Create the pair of UDP sockets used for pinging.
 *
 * @listen_ip:   address the receiving socket is bound to (port chosen by
 *               the kernel).
 * @u_timeout:   receive timeout of the receiving socket, in microseconds.
 * @server_port: out, the port the receiving socket got, host byte order.
 * @sock:        out, sock[0] is the bound receiving socket, sock[1] the
 *               sending one.
 *
 * Returns 0 on success; on failure returns -1 with no socket left open.
 */
static int udp_ping_init(struct in_addr listen_ip, unsigned int u_timeout,
		unsigned int *server_port, int sock[2])
{
	struct timeval rcv_timeout = { .tv_sec = 0, .tv_usec = u_timeout };
	struct sockaddr_in local;
	socklen_t local_len = sizeof(local);

	sock[0] = socket(AF_INET, SOCK_DGRAM, 0);
	if (sock[0] < 0) {
		pr_err("socket()");
		return -1;
	}

	local.sin_family	= AF_INET;
	local.sin_port		= 0;
	memcpy(&local.sin_addr.s_addr, &listen_ip, sizeof(struct in_addr));

	if (bind(sock[0], (struct sockaddr *)&local, local_len)) {
		pr_err("bind()");
		goto fail;
	}

	/* find out which port the kernel picked */
	if (getsockname(sock[0], (struct sockaddr *)&local, &local_len)) {
		pr_err("getsockname()");
		goto fail;
	}
	*server_port = ntohs(local.sin_port);

	if (setsockopt(sock[0], SOL_SOCKET, SO_RCVTIMEO,
			(const char *)&rcv_timeout, sizeof(rcv_timeout))) {
		pr_err("setsockopt()");
		goto fail;
	}

	sock[1] = socket(AF_INET, SOCK_DGRAM, 0);
	if (sock[1] >= 0)
		return 0;

	pr_err("socket()");
fail:
	close(sock[0]);
	return -1;
}
667
/*
 * Initiating side of one ping: send @buf to @dest_ip:@port through
 * sock[1], then wait on sock[0] for the echo and compare it with @buf.
 *
 * @dest_ip is used as is for sin_addr.s_addr; @port is in host byte order.
 * Returns 0 when an identical packet came back, -1 otherwise (including a
 * receive timeout, which is not logged).
 */
static int udp_ping_send(int sock[2], in_addr_t dest_ip, unsigned int port,
		char *buf, size_t buf_len)
{
	struct sockaddr_in server;
	const struct sockaddr *dest_addr = (struct sockaddr *)&server;
	/* byte buffer for the echo - not an array of pointers */
	char sock_buf[buf_len];
	ssize_t r_bytes, s_bytes;

	server.sin_family	= AF_INET;
	server.sin_port		= htons(port);
	server.sin_addr.s_addr	= dest_ip;

	s_bytes = sendto(sock[1], buf, buf_len, 0, dest_addr, sizeof(server));
	if (s_bytes < 0) {
		pr_err("sendto()");
		return -1;
	} else if ((size_t)s_bytes != buf_len) {
		printk("send part of the message: %zd/%zu", s_bytes, buf_len);
		return -1;
	}

	r_bytes = recv(sock[0], sock_buf, buf_len, 0);
	if (r_bytes < 0) {
		/* EAGAIN is the receive timeout: a lost ping, not an error */
		if (errno != EAGAIN)
			pr_err("recv()");
		return -1;
	} else if (r_bytes == 0) { /* EOF */
		printk("EOF on reply to ping");
		return -1;
	} else if ((size_t)r_bytes != buf_len || memcmp(buf, sock_buf, buf_len)) {
		printk("ping reply packet is corrupted %zd/%zu", r_bytes, buf_len);
		return -1;
	}

	return 0;
}
704
/*
 * Replying side of one ping: wait on sock[0] for a packet, check that it
 * equals @buf and echo @buf to @dest_ip:@port through sock[1].
 *
 * @dest_ip is used as is for sin_addr.s_addr; @port is in host byte order.
 * Returns 0 when the expected packet arrived and the echo was sent,
 * -1 otherwise (including a receive timeout, which is not logged).
 */
static int udp_ping_reply(int sock[2], in_addr_t dest_ip, unsigned int port,
		char *buf, size_t buf_len)
{
	struct sockaddr_in server;
	const struct sockaddr *dest_addr = (struct sockaddr *)&server;
	/* byte buffer for the incoming packet - not an array of pointers */
	char sock_buf[buf_len];
	ssize_t r_bytes, s_bytes;

	server.sin_family	= AF_INET;
	server.sin_port		= htons(port);
	server.sin_addr.s_addr	= dest_ip;

	r_bytes = recv(sock[0], sock_buf, buf_len, 0);
	if (r_bytes < 0) {
		/* EAGAIN is the receive timeout: a lost ping, not an error */
		if (errno != EAGAIN)
			pr_err("recv()");
		return -1;
	}
	if (r_bytes == 0) { /* EOF */
		printk("EOF on reply to ping");
		return -1;
	}
	if ((size_t)r_bytes != buf_len || memcmp(buf, sock_buf, buf_len)) {
		printk("ping reply packet is corrupted %zd/%zu", r_bytes, buf_len);
		return -1;
	}

	s_bytes = sendto(sock[1], buf, buf_len, 0, dest_addr, sizeof(server));
	if (s_bytes < 0) {
		pr_err("sendto()");
		return -1;
	} else if ((size_t)s_bytes != buf_len) {
		printk("send part of the message: %zd/%zu", s_bytes, buf_len);
		return -1;
	}

	return 0;
}
743
744typedef int (*ping_f)(int sock[2], in_addr_t dest_ip, unsigned int port,
745		char *buf, size_t buf_len);
746static int do_ping(int cmd_fd, char *buf, size_t buf_len, struct in_addr from,
747		bool init_side, int d_port, in_addr_t to, ping_f func)
748{
749	struct test_desc msg;
750	unsigned int s_port, i, ping_succeeded = 0;
751	int ping_sock[2];
752	char to_str[IPV4_STR_SZ] = {}, from_str[IPV4_STR_SZ] = {};
753
754	if (udp_ping_init(from, ping_timeout, &s_port, ping_sock)) {
755		printk("Failed to init ping");
756		return -1;
757	}
758
759	memset(&msg, 0, sizeof(msg));
760	msg.type		= MSG_PING;
761	msg.body.ping.port	= s_port;
762	memcpy(&msg.body.ping.reply_ip, &from, sizeof(from));
763
764	write_msg(cmd_fd, &msg, 0);
765	if (init_side) {
766		/* The other end sends ip to ping */
767		read_msg(cmd_fd, &msg, 0);
768		if (msg.type != MSG_PING)
769			return -1;
770		to = msg.body.ping.reply_ip;
771		d_port = msg.body.ping.port;
772	}
773
774	for (i = 0; i < ping_count ; i++) {
775		struct timespec sleep_time = {
776			.tv_sec = 0,
777			.tv_nsec = ping_delay_nsec,
778		};
779
780		ping_succeeded += !func(ping_sock, to, d_port, buf, page_size);
781		nanosleep(&sleep_time, 0);
782	}
783
784	close(ping_sock[0]);
785	close(ping_sock[1]);
786
787	strncpy(to_str, inet_ntoa(*(struct in_addr *)&to), IPV4_STR_SZ - 1);
788	strncpy(from_str, inet_ntoa(from), IPV4_STR_SZ - 1);
789
790	if (ping_succeeded < ping_success) {
791		printk("ping (%s) %s->%s failed %u/%u times",
792			init_side ? "send" : "reply", from_str, to_str,
793			ping_count - ping_succeeded, ping_count);
794		return -1;
795	}
796
797#ifdef DEBUG
798	printk("ping (%s) %s->%s succeeded %u/%u times",
799		init_side ? "send" : "reply", from_str, to_str,
800		ping_succeeded, ping_count);
801#endif
802
803	return 0;
804}
805
806static int xfrm_fill_key(char *name, char *buf,
807		size_t buf_len, unsigned int *key_len)
808{
809	int i;
810
811	for (i = 0; i < XFRM_ALGO_NR_KEYS; i++) {
812		if (strncmp(name, xfrm_key_entries[i].algo_name, ALGO_LEN) == 0)
813			*key_len = xfrm_key_entries[i].key_len;
814	}
815
816	if (*key_len > buf_len) {
817		printk("Can't pack a key - too big for buffer");
818		return -1;
819	}
820
821	randomize_buffer(buf, *key_len);
822
823	return 0;
824}
825
826static int xfrm_state_pack_algo(struct nlmsghdr *nh, size_t req_sz,
827		struct xfrm_desc *desc)
828{
829	struct {
830		union {
831			struct xfrm_algo	alg;
832			struct xfrm_algo_aead	aead;
833			struct xfrm_algo_auth	auth;
834		} u;
835		char buf[XFRM_ALGO_KEY_BUF_SIZE];
836	} alg = {};
837	size_t alen, elen, clen, aelen;
838	unsigned short type;
839
840	alen = strlen(desc->a_algo);
841	elen = strlen(desc->e_algo);
842	clen = strlen(desc->c_algo);
843	aelen = strlen(desc->ae_algo);
844
845	/* Verify desc */
846	switch (desc->proto) {
847	case IPPROTO_AH:
848		if (!alen || elen || clen || aelen) {
849			printk("BUG: buggy ah desc");
850			return -1;
851		}
852		strncpy(alg.u.alg.alg_name, desc->a_algo, ALGO_LEN - 1);
853		if (xfrm_fill_key(desc->a_algo, alg.u.alg.alg_key,
854				sizeof(alg.buf), &alg.u.alg.alg_key_len))
855			return -1;
856		type = XFRMA_ALG_AUTH;
857		break;
858	case IPPROTO_COMP:
859		if (!clen || elen || alen || aelen) {
860			printk("BUG: buggy comp desc");
861			return -1;
862		}
863		strncpy(alg.u.alg.alg_name, desc->c_algo, ALGO_LEN - 1);
864		if (xfrm_fill_key(desc->c_algo, alg.u.alg.alg_key,
865				sizeof(alg.buf), &alg.u.alg.alg_key_len))
866			return -1;
867		type = XFRMA_ALG_COMP;
868		break;
869	case IPPROTO_ESP:
870		if (!((alen && elen) ^ aelen) || clen) {
871			printk("BUG: buggy esp desc");
872			return -1;
873		}
874		if (aelen) {
875			alg.u.aead.alg_icv_len = desc->icv_len;
876			strncpy(alg.u.aead.alg_name, desc->ae_algo, ALGO_LEN - 1);
877			if (xfrm_fill_key(desc->ae_algo, alg.u.aead.alg_key,
878						sizeof(alg.buf), &alg.u.aead.alg_key_len))
879				return -1;
880			type = XFRMA_ALG_AEAD;
881		} else {
882
883			strncpy(alg.u.alg.alg_name, desc->e_algo, ALGO_LEN - 1);
884			type = XFRMA_ALG_CRYPT;
885			if (xfrm_fill_key(desc->e_algo, alg.u.alg.alg_key,
886						sizeof(alg.buf), &alg.u.alg.alg_key_len))
887				return -1;
888			if (rtattr_pack(nh, req_sz, type, &alg, sizeof(alg)))
889				return -1;
890
891			strncpy(alg.u.alg.alg_name, desc->a_algo, ALGO_LEN);
892			type = XFRMA_ALG_AUTH;
893			if (xfrm_fill_key(desc->a_algo, alg.u.alg.alg_key,
894						sizeof(alg.buf), &alg.u.alg.alg_key_len))
895				return -1;
896		}
897		break;
898	default:
899		printk("BUG: unknown proto in desc");
900		return -1;
901	}
902
903	if (rtattr_pack(nh, req_sz, type, &alg, sizeof(alg)))
904		return -1;
905
906	return 0;
907}
908
/*
 * Derive the SPI of a tunnel from its source address: the local (host)
 * part of @src as returned by inet_lnaof(), in network byte order.
 */
static inline uint32_t gen_spi(struct in_addr src)
{
	in_addr_t host_part = inet_lnaof(src);

	return htonl(host_part);
}
913
914static int xfrm_state_add(int xfrm_sock, uint32_t seq, uint32_t spi,
915		struct in_addr src, struct in_addr dst,
916		struct xfrm_desc *desc)
917{
918	struct {
919		struct nlmsghdr		nh;
920		struct xfrm_usersa_info	info;
921		char			attrbuf[MAX_PAYLOAD];
922	} req;
923
924	memset(&req, 0, sizeof(req));
925	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));
926	req.nh.nlmsg_type	= XFRM_MSG_NEWSA;
927	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
928	req.nh.nlmsg_seq	= seq;
929
930	/* Fill selector. */
931	memcpy(&req.info.sel.daddr, &dst, sizeof(dst));
932	memcpy(&req.info.sel.saddr, &src, sizeof(src));
933	req.info.sel.family		= AF_INET;
934	req.info.sel.prefixlen_d	= PREFIX_LEN;
935	req.info.sel.prefixlen_s	= PREFIX_LEN;
936
937	/* Fill id */
938	memcpy(&req.info.id.daddr, &dst, sizeof(dst));
939	/* Note: zero-spi cannot be deleted */
940	req.info.id.spi = spi;
941	req.info.id.proto	= desc->proto;
942
943	memcpy(&req.info.saddr, &src, sizeof(src));
944
945	/* Fill lifteme_cfg */
946	req.info.lft.soft_byte_limit	= XFRM_INF;
947	req.info.lft.hard_byte_limit	= XFRM_INF;
948	req.info.lft.soft_packet_limit	= XFRM_INF;
949	req.info.lft.hard_packet_limit	= XFRM_INF;
950
951	req.info.family		= AF_INET;
952	req.info.mode		= XFRM_MODE_TUNNEL;
953
954	if (xfrm_state_pack_algo(&req.nh, sizeof(req), desc))
955		return -1;
956
957	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
958		pr_err("send()");
959		return -1;
960	}
961
962	return netlink_check_answer(xfrm_sock);
963}
964
965static bool xfrm_usersa_found(struct xfrm_usersa_info *info, uint32_t spi,
966		struct in_addr src, struct in_addr dst,
967		struct xfrm_desc *desc)
968{
969	if (memcmp(&info->sel.daddr, &dst, sizeof(dst)))
970		return false;
971
972	if (memcmp(&info->sel.saddr, &src, sizeof(src)))
973		return false;
974
975	if (info->sel.family != AF_INET					||
976			info->sel.prefixlen_d != PREFIX_LEN		||
977			info->sel.prefixlen_s != PREFIX_LEN)
978		return false;
979
980	if (info->id.spi != spi || info->id.proto != desc->proto)
981		return false;
982
983	if (memcmp(&info->id.daddr, &dst, sizeof(dst)))
984		return false;
985
986	if (memcmp(&info->saddr, &src, sizeof(src)))
987		return false;
988
989	if (info->lft.soft_byte_limit != XFRM_INF			||
990			info->lft.hard_byte_limit != XFRM_INF		||
991			info->lft.soft_packet_limit != XFRM_INF		||
992			info->lft.hard_packet_limit != XFRM_INF)
993		return false;
994
995	if (info->family != AF_INET || info->mode != XFRM_MODE_TUNNEL)
996		return false;
997
998	/* XXX: check xfrm algo, see xfrm_state_pack_algo(). */
999
1000	return true;
1001}
1002
1003static int xfrm_state_check(int xfrm_sock, uint32_t seq, uint32_t spi,
1004		struct in_addr src, struct in_addr dst,
1005		struct xfrm_desc *desc)
1006{
1007	struct {
1008		struct nlmsghdr		nh;
1009		char			attrbuf[MAX_PAYLOAD];
1010	} req;
1011	struct {
1012		struct nlmsghdr		nh;
1013		union {
1014			struct xfrm_usersa_info	info;
1015			int error;
1016		};
1017		char			attrbuf[MAX_PAYLOAD];
1018	} answer;
1019	struct xfrm_address_filter filter = {};
1020	bool found = false;
1021
1022
1023	memset(&req, 0, sizeof(req));
1024	req.nh.nlmsg_len	= NLMSG_LENGTH(0);
1025	req.nh.nlmsg_type	= XFRM_MSG_GETSA;
1026	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_DUMP;
1027	req.nh.nlmsg_seq	= seq;
1028
1029	/*
1030	 * Add dump filter by source address as there may be other tunnels
1031	 * in this netns (if tests run in parallel).
1032	 */
1033	filter.family = AF_INET;
1034	filter.splen = 0x1f;	/* 0xffffffff mask see addr_match() */
1035	memcpy(&filter.saddr, &src, sizeof(src));
1036	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_ADDRESS_FILTER,
1037				&filter, sizeof(filter)))
1038		return -1;
1039
1040	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1041		pr_err("send()");
1042		return -1;
1043	}
1044
1045	while (1) {
1046		if (recv(xfrm_sock, &answer, sizeof(answer), 0) < 0) {
1047			pr_err("recv()");
1048			return -1;
1049		}
1050		if (answer.nh.nlmsg_type == NLMSG_ERROR) {
1051			printk("NLMSG_ERROR: %d: %s",
1052				answer.error, strerror(-answer.error));
1053			return -1;
1054		} else if (answer.nh.nlmsg_type == NLMSG_DONE) {
1055			if (found)
1056				return 0;
1057			printk("didn't find allocated xfrm state in dump");
1058			return -1;
1059		} else if (answer.nh.nlmsg_type == XFRM_MSG_NEWSA) {
1060			if (xfrm_usersa_found(&answer.info, spi, src, dst, desc))
1061				found = true;
1062		}
1063	}
1064}
1065
/*
 * Install the xfrm states for both directions between @src and @dst and
 * verify through a dump that both are present.  The SPI is derived from
 * @src for both directions.  Four sequence numbers are taken from *@seq
 * when nothing fails.  @tunsrc and @tundst are not used here.
 *
 * Returns 0 on success, -1 on failure.
 */
static int xfrm_set(int xfrm_sock, uint32_t *seq,
		struct in_addr src, struct in_addr dst,
		struct in_addr tunsrc, struct in_addr tundst,
		struct xfrm_desc *desc)
{
	const uint32_t spi = gen_spi(src);
	int fwd, rev;

	if (xfrm_state_add(xfrm_sock, (*seq)++, spi, src, dst, desc)) {
		printk("Failed to add xfrm state");
		return -1;
	}

	if (xfrm_state_add(xfrm_sock, (*seq)++, spi, dst, src, desc)) {
		printk("Failed to add xfrm state");
		return -1;
	}

	/* Check dumps for XFRM_MSG_GETSA; both directions are always checked */
	fwd = xfrm_state_check(xfrm_sock, (*seq)++, spi, src, dst, desc);
	rev = xfrm_state_check(xfrm_sock, (*seq)++, spi, dst, src, desc);
	if (fwd | rev) {
		printk("Failed to check xfrm state");
		return -1;
	}

	return 0;
}
1095
1096static int xfrm_policy_add(int xfrm_sock, uint32_t seq, uint32_t spi,
1097		struct in_addr src, struct in_addr dst, uint8_t dir,
1098		struct in_addr tunsrc, struct in_addr tundst, uint8_t proto)
1099{
1100	struct {
1101		struct nlmsghdr			nh;
1102		struct xfrm_userpolicy_info	info;
1103		char				attrbuf[MAX_PAYLOAD];
1104	} req;
1105	struct xfrm_user_tmpl tmpl;
1106
1107	memset(&req, 0, sizeof(req));
1108	memset(&tmpl, 0, sizeof(tmpl));
1109	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));
1110	req.nh.nlmsg_type	= XFRM_MSG_NEWPOLICY;
1111	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1112	req.nh.nlmsg_seq	= seq;
1113
1114	/* Fill selector. */
1115	memcpy(&req.info.sel.daddr, &dst, sizeof(tundst));
1116	memcpy(&req.info.sel.saddr, &src, sizeof(tunsrc));
1117	req.info.sel.family		= AF_INET;
1118	req.info.sel.prefixlen_d	= PREFIX_LEN;
1119	req.info.sel.prefixlen_s	= PREFIX_LEN;
1120
1121	/* Fill lifteme_cfg */
1122	req.info.lft.soft_byte_limit	= XFRM_INF;
1123	req.info.lft.hard_byte_limit	= XFRM_INF;
1124	req.info.lft.soft_packet_limit	= XFRM_INF;
1125	req.info.lft.hard_packet_limit	= XFRM_INF;
1126
1127	req.info.dir = dir;
1128
1129	/* Fill tmpl */
1130	memcpy(&tmpl.id.daddr, &dst, sizeof(dst));
1131	/* Note: zero-spi cannot be deleted */
1132	tmpl.id.spi = spi;
1133	tmpl.id.proto	= proto;
1134	tmpl.family	= AF_INET;
1135	memcpy(&tmpl.saddr, &src, sizeof(src));
1136	tmpl.mode	= XFRM_MODE_TUNNEL;
1137	tmpl.aalgos = (~(uint32_t)0);
1138	tmpl.ealgos = (~(uint32_t)0);
1139	tmpl.calgos = (~(uint32_t)0);
1140
1141	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_TMPL, &tmpl, sizeof(tmpl)))
1142		return -1;
1143
1144	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1145		pr_err("send()");
1146		return -1;
1147	}
1148
1149	return netlink_check_answer(xfrm_sock);
1150}
1151
/*
 * Install the two policies of a tunnel: an outgoing one for @src -> @dst
 * and an incoming one for @dst -> @src.  The SPI is derived from @src for
 * both; each request takes one sequence number from *@seq.
 *
 * Returns 0 on success, -1 on the first failure.
 */
static int xfrm_prepare(int xfrm_sock, uint32_t *seq,
		struct in_addr src, struct in_addr dst,
		struct in_addr tunsrc, struct in_addr tundst, uint8_t proto)
{
	int err;

	err = xfrm_policy_add(xfrm_sock, (*seq)++, gen_spi(src), src, dst,
			XFRM_POLICY_OUT, tunsrc, tundst, proto);
	if (err) {
		printk("Failed to add xfrm policy");
		return -1;
	}

	err = xfrm_policy_add(xfrm_sock, (*seq)++, gen_spi(src), dst, src,
			XFRM_POLICY_IN, tunsrc, tundst, proto);
	if (err) {
		printk("Failed to add xfrm policy");
		return -1;
	}

	return 0;
}
1170
1171static int xfrm_policy_del(int xfrm_sock, uint32_t seq,
1172		struct in_addr src, struct in_addr dst, uint8_t dir,
1173		struct in_addr tunsrc, struct in_addr tundst)
1174{
1175	struct {
1176		struct nlmsghdr			nh;
1177		struct xfrm_userpolicy_id	id;
1178		char				attrbuf[MAX_PAYLOAD];
1179	} req;
1180
1181	memset(&req, 0, sizeof(req));
1182	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.id));
1183	req.nh.nlmsg_type	= XFRM_MSG_DELPOLICY;
1184	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1185	req.nh.nlmsg_seq	= seq;
1186
1187	/* Fill id */
1188	memcpy(&req.id.sel.daddr, &dst, sizeof(tundst));
1189	memcpy(&req.id.sel.saddr, &src, sizeof(tunsrc));
1190	req.id.sel.family		= AF_INET;
1191	req.id.sel.prefixlen_d		= PREFIX_LEN;
1192	req.id.sel.prefixlen_s		= PREFIX_LEN;
1193	req.id.dir = dir;
1194
1195	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1196		pr_err("send()");
1197		return -1;
1198	}
1199
1200	return netlink_check_answer(xfrm_sock);
1201}
1202
/*
 * Delete the two policies that xfrm_prepare() installs: the outbound
 * policy for src -> dst and the inbound policy for dst -> src.
 * *seq is advanced once per delete request.
 *
 * Returns 0 when both policies were deleted, -1 as soon as one delete
 * fails (the remaining policy is then not attempted).
 */
static int xfrm_cleanup(int xfrm_sock, uint32_t *seq,
		struct in_addr src, struct in_addr dst,
		struct in_addr tunsrc, struct in_addr tundst)
{
	if (xfrm_policy_del(xfrm_sock, (*seq)++, src, dst,
				XFRM_POLICY_OUT, tunsrc, tundst)) {
		/* This is the delete path: don't report it as an "add" failure */
		printk("Failed to delete xfrm policy");
		return -1;
	}

	if (xfrm_policy_del(xfrm_sock, (*seq)++, dst, src,
				XFRM_POLICY_IN, tunsrc, tundst)) {
		printk("Failed to delete xfrm policy");
		return -1;
	}

	return 0;
}
1221
1222static int xfrm_state_del(int xfrm_sock, uint32_t seq, uint32_t spi,
1223		struct in_addr src, struct in_addr dst, uint8_t proto)
1224{
1225	struct {
1226		struct nlmsghdr		nh;
1227		struct xfrm_usersa_id	id;
1228		char			attrbuf[MAX_PAYLOAD];
1229	} req;
1230	xfrm_address_t saddr = {};
1231
1232	memset(&req, 0, sizeof(req));
1233	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.id));
1234	req.nh.nlmsg_type	= XFRM_MSG_DELSA;
1235	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1236	req.nh.nlmsg_seq	= seq;
1237
1238	memcpy(&req.id.daddr, &dst, sizeof(dst));
1239	req.id.family		= AF_INET;
1240	req.id.proto		= proto;
1241	/* Note: zero-spi cannot be deleted */
1242	req.id.spi = spi;
1243
1244	memcpy(&saddr, &src, sizeof(src));
1245	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_SRCADDR, &saddr, sizeof(saddr)))
1246		return -1;
1247
1248	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1249		pr_err("send()");
1250		return -1;
1251	}
1252
1253	return netlink_check_answer(xfrm_sock);
1254}
1255
/*
 * Delete both xfrm states of a connection: first the src -> dst state,
 * then the dst -> src one. Both are looked up with gen_spi(src).
 * @tunsrc and @tundst are not used here.
 *
 * Returns 0 on success, -1 as soon as one deletion fails.
 */
static int xfrm_delete(int xfrm_sock, uint32_t *seq,
		struct in_addr src, struct in_addr dst,
		struct in_addr tunsrc, struct in_addr tundst, uint8_t proto)
{
	const struct in_addr from[2] = { src, dst };
	const struct in_addr to[2] = { dst, src };
	unsigned int dir;

	for (dir = 0; dir < 2; dir++) {
		int err = xfrm_state_del(xfrm_sock, (*seq)++, gen_spi(src),
					 from[dir], to[dir], proto);

		if (err) {
			printk("Failed to remove xfrm state");
			return -1;
		}
	}

	return 0;
}
1272
1273static int xfrm_state_allocspi(int xfrm_sock, uint32_t *seq,
1274		uint32_t spi, uint8_t proto)
1275{
1276	struct {
1277		struct nlmsghdr			nh;
1278		struct xfrm_userspi_info	spi;
1279	} req;
1280	struct {
1281		struct nlmsghdr			nh;
1282		union {
1283			struct xfrm_usersa_info	info;
1284			int error;
1285		};
1286	} answer;
1287
1288	memset(&req, 0, sizeof(req));
1289	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.spi));
1290	req.nh.nlmsg_type	= XFRM_MSG_ALLOCSPI;
1291	req.nh.nlmsg_flags	= NLM_F_REQUEST;
1292	req.nh.nlmsg_seq	= (*seq)++;
1293
1294	req.spi.info.family	= AF_INET;
1295	req.spi.min		= spi;
1296	req.spi.max		= spi;
1297	req.spi.info.id.proto	= proto;
1298
1299	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1300		pr_err("send()");
1301		return KSFT_FAIL;
1302	}
1303
1304	if (recv(xfrm_sock, &answer, sizeof(answer), 0) < 0) {
1305		pr_err("recv()");
1306		return KSFT_FAIL;
1307	} else if (answer.nh.nlmsg_type == XFRM_MSG_NEWSA) {
1308		uint32_t new_spi = htonl(answer.info.id.spi);
1309
1310		if (new_spi != spi) {
1311			printk("allocated spi is different from requested: %#x != %#x",
1312					new_spi, spi);
1313			return KSFT_FAIL;
1314		}
1315		return KSFT_PASS;
1316	} else if (answer.nh.nlmsg_type != NLMSG_ERROR) {
1317		printk("expected NLMSG_ERROR, got %d", (int)answer.nh.nlmsg_type);
1318		return KSFT_FAIL;
1319	}
1320
1321	printk("NLMSG_ERROR: %d: %s", answer.error, strerror(-answer.error));
1322	return (answer.error) ? KSFT_FAIL : KSFT_PASS;
1323}
1324
/*
 * Open a netlink socket with netlink_sock() and bind it to the multicast
 * group mask @groups, then verify through getsockname() that the bound
 * address has the expected length and family.
 *
 * On success *sock holds the bound socket and 0 is returned. On failure
 * after the socket was opened it is closed again and -1 is returned.
 */
static int netlink_sock_bind(int *sock, uint32_t *seq, int proto, uint32_t groups)
{
	struct sockaddr_nl addr = {};
	socklen_t len = sizeof(addr);

	if (netlink_sock(sock, seq, proto)) {
		printk("Failed to open xfrm netlink socket");
		return -1;
	}

	addr.nl_family = AF_NETLINK;
	addr.nl_groups = groups;
	if (bind(*sock, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
		pr_err("bind()");
		goto err_close;
	}

	/* Read the address back and sanity-check what the kernel reports */
	if (getsockname(*sock, (struct sockaddr *)&addr, &len) < 0) {
		pr_err("getsockname()");
		goto err_close;
	}
	if (len != sizeof(addr)) {
		printk("Wrong address length %d", len);
		goto err_close;
	}
	if (addr.nl_family != AF_NETLINK) {
		printk("Wrong address family %d", addr.nl_family);
		goto err_close;
	}

	return 0;

err_close:
	close(*sock);
	return -1;
}
1363
1364static int xfrm_monitor_acquire(int xfrm_sock, uint32_t *seq, unsigned int nr)
1365{
1366	struct {
1367		struct nlmsghdr nh;
1368		union {
1369			struct xfrm_user_acquire acq;
1370			int error;
1371		};
1372		char attrbuf[MAX_PAYLOAD];
1373	} req;
1374	struct xfrm_user_tmpl xfrm_tmpl = {};
1375	int xfrm_listen = -1, ret = KSFT_FAIL;
1376	uint32_t seq_listen;
1377
1378	if (netlink_sock_bind(&xfrm_listen, &seq_listen, NETLINK_XFRM, XFRMNLGRP_ACQUIRE))
1379		return KSFT_FAIL;
1380
1381	memset(&req, 0, sizeof(req));
1382	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.acq));
1383	req.nh.nlmsg_type	= XFRM_MSG_ACQUIRE;
1384	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1385	req.nh.nlmsg_seq	= (*seq)++;
1386
1387	req.acq.policy.sel.family	= AF_INET;
1388	req.acq.aalgos	= 0xfeed;
1389	req.acq.ealgos	= 0xbaad;
1390	req.acq.calgos	= 0xbabe;
1391
1392	xfrm_tmpl.family = AF_INET;
1393	xfrm_tmpl.id.proto = IPPROTO_ESP;
1394	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_TMPL, &xfrm_tmpl, sizeof(xfrm_tmpl)))
1395		goto out_close;
1396
1397	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1398		pr_err("send()");
1399		goto out_close;
1400	}
1401
1402	if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1403		pr_err("recv()");
1404		goto out_close;
1405	} else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1406		printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1407		goto out_close;
1408	}
1409
1410	if (req.error) {
1411		printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1412		ret = req.error;
1413		goto out_close;
1414	}
1415
1416	if (recv(xfrm_listen, &req, sizeof(req), 0) < 0) {
1417		pr_err("recv()");
1418		goto out_close;
1419	}
1420
1421	if (req.acq.aalgos != 0xfeed || req.acq.ealgos != 0xbaad
1422			|| req.acq.calgos != 0xbabe) {
1423		printk("xfrm_user_acquire has changed  %x %x %x",
1424				req.acq.aalgos, req.acq.ealgos, req.acq.calgos);
1425		goto out_close;
1426	}
1427
1428	ret = KSFT_PASS;
1429out_close:
1430	close(xfrm_listen);
1431	return ret;
1432}
1433
1434static int xfrm_expire_state(int xfrm_sock, uint32_t *seq,
1435		unsigned int nr, struct xfrm_desc *desc)
1436{
1437	struct {
1438		struct nlmsghdr nh;
1439		union {
1440			struct xfrm_user_expire expire;
1441			int error;
1442		};
1443	} req;
1444	struct in_addr src, dst;
1445	int xfrm_listen = -1, ret = KSFT_FAIL;
1446	uint32_t seq_listen;
1447
1448	src = inet_makeaddr(INADDR_B, child_ip(nr));
1449	dst = inet_makeaddr(INADDR_B, grchild_ip(nr));
1450
1451	if (xfrm_state_add(xfrm_sock, (*seq)++, gen_spi(src), src, dst, desc)) {
1452		printk("Failed to add xfrm state");
1453		return KSFT_FAIL;
1454	}
1455
1456	if (netlink_sock_bind(&xfrm_listen, &seq_listen, NETLINK_XFRM, XFRMNLGRP_EXPIRE))
1457		return KSFT_FAIL;
1458
1459	memset(&req, 0, sizeof(req));
1460	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.expire));
1461	req.nh.nlmsg_type	= XFRM_MSG_EXPIRE;
1462	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1463	req.nh.nlmsg_seq	= (*seq)++;
1464
1465	memcpy(&req.expire.state.id.daddr, &dst, sizeof(dst));
1466	req.expire.state.id.spi		= gen_spi(src);
1467	req.expire.state.id.proto	= desc->proto;
1468	req.expire.state.family		= AF_INET;
1469	req.expire.hard			= 0xff;
1470
1471	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1472		pr_err("send()");
1473		goto out_close;
1474	}
1475
1476	if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1477		pr_err("recv()");
1478		goto out_close;
1479	} else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1480		printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1481		goto out_close;
1482	}
1483
1484	if (req.error) {
1485		printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1486		ret = req.error;
1487		goto out_close;
1488	}
1489
1490	if (recv(xfrm_listen, &req, sizeof(req), 0) < 0) {
1491		pr_err("recv()");
1492		goto out_close;
1493	}
1494
1495	if (req.expire.hard != 0x1) {
1496		printk("expire.hard is not set: %x", req.expire.hard);
1497		goto out_close;
1498	}
1499
1500	ret = KSFT_PASS;
1501out_close:
1502	close(xfrm_listen);
1503	return ret;
1504}
1505
1506static int xfrm_expire_policy(int xfrm_sock, uint32_t *seq,
1507		unsigned int nr, struct xfrm_desc *desc)
1508{
1509	struct {
1510		struct nlmsghdr nh;
1511		union {
1512			struct xfrm_user_polexpire expire;
1513			int error;
1514		};
1515	} req;
1516	struct in_addr src, dst, tunsrc, tundst;
1517	int xfrm_listen = -1, ret = KSFT_FAIL;
1518	uint32_t seq_listen;
1519
1520	src = inet_makeaddr(INADDR_B, child_ip(nr));
1521	dst = inet_makeaddr(INADDR_B, grchild_ip(nr));
1522	tunsrc = inet_makeaddr(INADDR_A, child_ip(nr));
1523	tundst = inet_makeaddr(INADDR_A, grchild_ip(nr));
1524
1525	if (xfrm_policy_add(xfrm_sock, (*seq)++, gen_spi(src), src, dst,
1526				XFRM_POLICY_OUT, tunsrc, tundst, desc->proto)) {
1527		printk("Failed to add xfrm policy");
1528		return KSFT_FAIL;
1529	}
1530
1531	if (netlink_sock_bind(&xfrm_listen, &seq_listen, NETLINK_XFRM, XFRMNLGRP_EXPIRE))
1532		return KSFT_FAIL;
1533
1534	memset(&req, 0, sizeof(req));
1535	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.expire));
1536	req.nh.nlmsg_type	= XFRM_MSG_POLEXPIRE;
1537	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1538	req.nh.nlmsg_seq	= (*seq)++;
1539
1540	/* Fill selector. */
1541	memcpy(&req.expire.pol.sel.daddr, &dst, sizeof(tundst));
1542	memcpy(&req.expire.pol.sel.saddr, &src, sizeof(tunsrc));
1543	req.expire.pol.sel.family	= AF_INET;
1544	req.expire.pol.sel.prefixlen_d	= PREFIX_LEN;
1545	req.expire.pol.sel.prefixlen_s	= PREFIX_LEN;
1546	req.expire.pol.dir		= XFRM_POLICY_OUT;
1547	req.expire.hard			= 0xff;
1548
1549	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1550		pr_err("send()");
1551		goto out_close;
1552	}
1553
1554	if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1555		pr_err("recv()");
1556		goto out_close;
1557	} else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1558		printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1559		goto out_close;
1560	}
1561
1562	if (req.error) {
1563		printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1564		ret = req.error;
1565		goto out_close;
1566	}
1567
1568	if (recv(xfrm_listen, &req, sizeof(req), 0) < 0) {
1569		pr_err("recv()");
1570		goto out_close;
1571	}
1572
1573	if (req.expire.hard != 0x1) {
1574		printk("expire.hard is not set: %x", req.expire.hard);
1575		goto out_close;
1576	}
1577
1578	ret = KSFT_PASS;
1579out_close:
1580	close(xfrm_listen);
1581	return ret;
1582}
1583
1584static int xfrm_spdinfo_set_thresh(int xfrm_sock, uint32_t *seq,
1585		unsigned thresh4_l, unsigned thresh4_r,
1586		unsigned thresh6_l, unsigned thresh6_r,
1587		bool add_bad_attr)
1588
1589{
1590	struct {
1591		struct nlmsghdr		nh;
1592		union {
1593			uint32_t	unused;
1594			int		error;
1595		};
1596		char			attrbuf[MAX_PAYLOAD];
1597	} req;
1598	struct xfrmu_spdhthresh thresh;
1599
1600	memset(&req, 0, sizeof(req));
1601	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.unused));
1602	req.nh.nlmsg_type	= XFRM_MSG_NEWSPDINFO;
1603	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1604	req.nh.nlmsg_seq	= (*seq)++;
1605
1606	thresh.lbits = thresh4_l;
1607	thresh.rbits = thresh4_r;
1608	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_SPD_IPV4_HTHRESH, &thresh, sizeof(thresh)))
1609		return -1;
1610
1611	thresh.lbits = thresh6_l;
1612	thresh.rbits = thresh6_r;
1613	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_SPD_IPV6_HTHRESH, &thresh, sizeof(thresh)))
1614		return -1;
1615
1616	if (add_bad_attr) {
1617		BUILD_BUG_ON(XFRMA_IF_ID <= XFRMA_SPD_MAX + 1);
1618		if (rtattr_pack(&req.nh, sizeof(req), XFRMA_IF_ID, NULL, 0)) {
1619			pr_err("adding attribute failed: no space");
1620			return -1;
1621		}
1622	}
1623
1624	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1625		pr_err("send()");
1626		return -1;
1627	}
1628
1629	if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1630		pr_err("recv()");
1631		return -1;
1632	} else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1633		printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1634		return -1;
1635	}
1636
1637	if (req.error) {
1638		printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1639		return -1;
1640	}
1641
1642	return 0;
1643}
1644
1645static int xfrm_spdinfo_attrs(int xfrm_sock, uint32_t *seq)
1646{
1647	struct {
1648		struct nlmsghdr			nh;
1649		union {
1650			uint32_t	unused;
1651			int		error;
1652		};
1653		char			attrbuf[MAX_PAYLOAD];
1654	} req;
1655
1656	if (xfrm_spdinfo_set_thresh(xfrm_sock, seq, 32, 31, 120, 16, false)) {
1657		pr_err("Can't set SPD HTHRESH");
1658		return KSFT_FAIL;
1659	}
1660
1661	memset(&req, 0, sizeof(req));
1662
1663	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.unused));
1664	req.nh.nlmsg_type	= XFRM_MSG_GETSPDINFO;
1665	req.nh.nlmsg_flags	= NLM_F_REQUEST;
1666	req.nh.nlmsg_seq	= (*seq)++;
1667	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1668		pr_err("send()");
1669		return KSFT_FAIL;
1670	}
1671
1672	if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1673		pr_err("recv()");
1674		return KSFT_FAIL;
1675	} else if (req.nh.nlmsg_type == XFRM_MSG_NEWSPDINFO) {
1676		size_t len = NLMSG_PAYLOAD(&req.nh, sizeof(req.unused));
1677		struct rtattr *attr = (void *)req.attrbuf;
1678		int got_thresh = 0;
1679
1680		for (; RTA_OK(attr, len); attr = RTA_NEXT(attr, len)) {
1681			if (attr->rta_type == XFRMA_SPD_IPV4_HTHRESH) {
1682				struct xfrmu_spdhthresh *t = RTA_DATA(attr);
1683
1684				got_thresh++;
1685				if (t->lbits != 32 || t->rbits != 31) {
1686					pr_err("thresh differ: %u, %u",
1687							t->lbits, t->rbits);
1688					return KSFT_FAIL;
1689				}
1690			}
1691			if (attr->rta_type == XFRMA_SPD_IPV6_HTHRESH) {
1692				struct xfrmu_spdhthresh *t = RTA_DATA(attr);
1693
1694				got_thresh++;
1695				if (t->lbits != 120 || t->rbits != 16) {
1696					pr_err("thresh differ: %u, %u",
1697							t->lbits, t->rbits);
1698					return KSFT_FAIL;
1699				}
1700			}
1701		}
1702		if (got_thresh != 2) {
1703			pr_err("only %d thresh returned by XFRM_MSG_GETSPDINFO", got_thresh);
1704			return KSFT_FAIL;
1705		}
1706	} else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1707		printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1708		return KSFT_FAIL;
1709	} else {
1710		printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1711		return -1;
1712	}
1713
1714	/* Restore the default */
1715	if (xfrm_spdinfo_set_thresh(xfrm_sock, seq, 32, 32, 128, 128, false)) {
1716		pr_err("Can't restore SPD HTHRESH");
1717		return KSFT_FAIL;
1718	}
1719
1720	/*
1721	 * At this moment xfrm uses nlmsg_parse_deprecated(), which
1722	 * implies NL_VALIDATE_LIBERAL - ignoring attributes with
1723	 * (type > maxtype). nla_parse_depricated_strict() would enforce
1724	 * it. Or even stricter nla_parse().
1725	 * Right now it's not expected to fail, but to be ignored.
1726	 */
1727	if (xfrm_spdinfo_set_thresh(xfrm_sock, seq, 32, 32, 128, 128, true))
1728		return KSFT_PASS;
1729
1730	return KSFT_PASS;
1731}
1732
1733static int child_serv(int xfrm_sock, uint32_t *seq,
1734		unsigned int nr, int cmd_fd, void *buf, struct xfrm_desc *desc)
1735{
1736	struct in_addr src, dst, tunsrc, tundst;
1737	struct test_desc msg;
1738	int ret = KSFT_FAIL;
1739
1740	src = inet_makeaddr(INADDR_B, child_ip(nr));
1741	dst = inet_makeaddr(INADDR_B, grchild_ip(nr));
1742	tunsrc = inet_makeaddr(INADDR_A, child_ip(nr));
1743	tundst = inet_makeaddr(INADDR_A, grchild_ip(nr));
1744
1745	/* UDP pinging without xfrm */
1746	if (do_ping(cmd_fd, buf, page_size, src, true, 0, 0, udp_ping_send)) {
1747		printk("ping failed before setting xfrm");
1748		return KSFT_FAIL;
1749	}
1750
1751	memset(&msg, 0, sizeof(msg));
1752	msg.type = MSG_XFRM_PREPARE;
1753	memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
1754	write_msg(cmd_fd, &msg, 1);
1755
1756	if (xfrm_prepare(xfrm_sock, seq, src, dst, tunsrc, tundst, desc->proto)) {
1757		printk("failed to prepare xfrm");
1758		goto cleanup;
1759	}
1760
1761	memset(&msg, 0, sizeof(msg));
1762	msg.type = MSG_XFRM_ADD;
1763	memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
1764	write_msg(cmd_fd, &msg, 1);
1765	if (xfrm_set(xfrm_sock, seq, src, dst, tunsrc, tundst, desc)) {
1766		printk("failed to set xfrm");
1767		goto delete;
1768	}
1769
1770	/* UDP pinging with xfrm tunnel */
1771	if (do_ping(cmd_fd, buf, page_size, tunsrc,
1772				true, 0, 0, udp_ping_send)) {
1773		printk("ping failed for xfrm");
1774		goto delete;
1775	}
1776
1777	ret = KSFT_PASS;
1778delete:
1779	/* xfrm delete */
1780	memset(&msg, 0, sizeof(msg));
1781	msg.type = MSG_XFRM_DEL;
1782	memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
1783	write_msg(cmd_fd, &msg, 1);
1784
1785	if (xfrm_delete(xfrm_sock, seq, src, dst, tunsrc, tundst, desc->proto)) {
1786		printk("failed ping to remove xfrm");
1787		ret = KSFT_FAIL;
1788	}
1789
1790cleanup:
1791	memset(&msg, 0, sizeof(msg));
1792	msg.type = MSG_XFRM_CLEANUP;
1793	memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
1794	write_msg(cmd_fd, &msg, 1);
1795	if (xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst)) {
1796		printk("failed ping to cleanup xfrm");
1797		ret = KSFT_FAIL;
1798	}
1799	return ret;
1800}
1801
1802static int child_f(unsigned int nr, int test_desc_fd, int cmd_fd, void *buf)
1803{
1804	struct xfrm_desc desc;
1805	struct test_desc msg;
1806	int xfrm_sock = -1;
1807	uint32_t seq;
1808
1809	if (switch_ns(nsfd_childa))
1810		exit(KSFT_FAIL);
1811
1812	if (netlink_sock(&xfrm_sock, &seq, NETLINK_XFRM)) {
1813		printk("Failed to open xfrm netlink socket");
1814		exit(KSFT_FAIL);
1815	}
1816
1817	/* Check that seq sock is ready, just for sure. */
1818	memset(&msg, 0, sizeof(msg));
1819	msg.type = MSG_ACK;
1820	write_msg(cmd_fd, &msg, 1);
1821	read_msg(cmd_fd, &msg, 1);
1822	if (msg.type != MSG_ACK) {
1823		printk("Ack failed");
1824		exit(KSFT_FAIL);
1825	}
1826
1827	for (;;) {
1828		ssize_t received = read(test_desc_fd, &desc, sizeof(desc));
1829		int ret;
1830
1831		if (received == 0) /* EOF */
1832			break;
1833
1834		if (received != sizeof(desc)) {
1835			pr_err("read() returned %zd", received);
1836			exit(KSFT_FAIL);
1837		}
1838
1839		switch (desc.type) {
1840		case CREATE_TUNNEL:
1841			ret = child_serv(xfrm_sock, &seq, nr,
1842					 cmd_fd, buf, &desc);
1843			break;
1844		case ALLOCATE_SPI:
1845			ret = xfrm_state_allocspi(xfrm_sock, &seq,
1846						  -1, desc.proto);
1847			break;
1848		case MONITOR_ACQUIRE:
1849			ret = xfrm_monitor_acquire(xfrm_sock, &seq, nr);
1850			break;
1851		case EXPIRE_STATE:
1852			ret = xfrm_expire_state(xfrm_sock, &seq, nr, &desc);
1853			break;
1854		case EXPIRE_POLICY:
1855			ret = xfrm_expire_policy(xfrm_sock, &seq, nr, &desc);
1856			break;
1857		case SPDINFO_ATTRS:
1858			ret = xfrm_spdinfo_attrs(xfrm_sock, &seq);
1859			break;
1860		default:
1861			printk("Unknown desc type %d", desc.type);
1862			exit(KSFT_FAIL);
1863		}
1864		write_test_result(ret, &desc);
1865	}
1866
1867	close(xfrm_sock);
1868
1869	msg.type = MSG_EXIT;
1870	write_msg(cmd_fd, &msg, 1);
1871	exit(KSFT_PASS);
1872}
1873
1874static void grand_child_serv(unsigned int nr, int cmd_fd, void *buf,
1875		struct test_desc *msg, int xfrm_sock, uint32_t *seq)
1876{
1877	struct in_addr src, dst, tunsrc, tundst;
1878	bool tun_reply;
1879	struct xfrm_desc *desc = &msg->body.xfrm_desc;
1880
1881	src = inet_makeaddr(INADDR_B, grchild_ip(nr));
1882	dst = inet_makeaddr(INADDR_B, child_ip(nr));
1883	tunsrc = inet_makeaddr(INADDR_A, grchild_ip(nr));
1884	tundst = inet_makeaddr(INADDR_A, child_ip(nr));
1885
1886	switch (msg->type) {
1887	case MSG_EXIT:
1888		exit(KSFT_PASS);
1889	case MSG_ACK:
1890		write_msg(cmd_fd, msg, 1);
1891		break;
1892	case MSG_PING:
1893		tun_reply = memcmp(&dst, &msg->body.ping.reply_ip, sizeof(in_addr_t));
1894		/* UDP pinging without xfrm */
1895		if (do_ping(cmd_fd, buf, page_size, tun_reply ? tunsrc : src,
1896				false, msg->body.ping.port,
1897				msg->body.ping.reply_ip, udp_ping_reply)) {
1898			printk("ping failed before setting xfrm");
1899		}
1900		break;
1901	case MSG_XFRM_PREPARE:
1902		if (xfrm_prepare(xfrm_sock, seq, src, dst, tunsrc, tundst,
1903					desc->proto)) {
1904			xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst);
1905			printk("failed to prepare xfrm");
1906		}
1907		break;
1908	case MSG_XFRM_ADD:
1909		if (xfrm_set(xfrm_sock, seq, src, dst, tunsrc, tundst, desc)) {
1910			xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst);
1911			printk("failed to set xfrm");
1912		}
1913		break;
1914	case MSG_XFRM_DEL:
1915		if (xfrm_delete(xfrm_sock, seq, src, dst, tunsrc, tundst,
1916					desc->proto)) {
1917			xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst);
1918			printk("failed to remove xfrm");
1919		}
1920		break;
1921	case MSG_XFRM_CLEANUP:
1922		if (xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst)) {
1923			printk("failed to cleanup xfrm");
1924		}
1925		break;
1926	default:
1927		printk("got unknown msg type %d", msg->type);
1928	}
1929}
1930
1931static int grand_child_f(unsigned int nr, int cmd_fd, void *buf)
1932{
1933	struct test_desc msg;
1934	int xfrm_sock = -1;
1935	uint32_t seq;
1936
1937	if (switch_ns(nsfd_childb))
1938		exit(KSFT_FAIL);
1939
1940	if (netlink_sock(&xfrm_sock, &seq, NETLINK_XFRM)) {
1941		printk("Failed to open xfrm netlink socket");
1942		exit(KSFT_FAIL);
1943	}
1944
1945	do {
1946		read_msg(cmd_fd, &msg, 1);
1947		grand_child_serv(nr, cmd_fd, buf, &msg, xfrm_sock, &seq);
1948	} while (1);
1949
1950	close(xfrm_sock);
1951	exit(KSFT_FAIL);
1952}
1953
1954static int start_child(unsigned int nr, char *veth, int test_desc_fd[2])
1955{
1956	int cmd_sock[2];
1957	void *data_map;
1958	pid_t child;
1959
1960	if (init_child(nsfd_childa, veth, child_ip(nr), grchild_ip(nr)))
1961		return -1;
1962
1963	if (init_child(nsfd_childb, veth, grchild_ip(nr), child_ip(nr)))
1964		return -1;
1965
1966	child = fork();
1967	if (child < 0) {
1968		pr_err("fork()");
1969		return -1;
1970	} else if (child) {
1971		/* in parent - selftest */
1972		return switch_ns(nsfd_parent);
1973	}
1974
1975	if (close(test_desc_fd[1])) {
1976		pr_err("close()");
1977		return -1;
1978	}
1979
1980	/* child */
1981	data_map = mmap(0, page_size, PROT_READ | PROT_WRITE,
1982			MAP_SHARED | MAP_ANONYMOUS, -1, 0);
1983	if (data_map == MAP_FAILED) {
1984		pr_err("mmap()");
1985		return -1;
1986	}
1987
1988	randomize_buffer(data_map, page_size);
1989
1990	if (socketpair(PF_LOCAL, SOCK_SEQPACKET, 0, cmd_sock)) {
1991		pr_err("socketpair()");
1992		return -1;
1993	}
1994
1995	child = fork();
1996	if (child < 0) {
1997		pr_err("fork()");
1998		return -1;
1999	} else if (child) {
2000		if (close(cmd_sock[0])) {
2001			pr_err("close()");
2002			return -1;
2003		}
2004		return child_f(nr, test_desc_fd[0], cmd_sock[1], data_map);
2005	}
2006	if (close(cmd_sock[1])) {
2007		pr_err("close()");
2008		return -1;
2009	}
2010	return grand_child_f(nr, cmd_sock[0], data_map);
2011}
2012
2013static void exit_usage(char **argv)
2014{
2015	printk("Usage: %s [nr_process]", argv[0]);
2016	exit(KSFT_FAIL);
2017}
2018
2019static int __write_desc(int test_desc_fd, struct xfrm_desc *desc)
2020{
2021	ssize_t ret;
2022
2023	ret = write(test_desc_fd, desc, sizeof(*desc));
2024
2025	if (ret == sizeof(*desc))
2026		return 0;
2027
2028	pr_err("Writing test's desc failed %ld", ret);
2029
2030	return -1;
2031}
2032
2033static int write_desc(int proto, int test_desc_fd,
2034		char *a, char *e, char *c, char *ae)
2035{
2036	struct xfrm_desc desc = {};
2037
2038	desc.type = CREATE_TUNNEL;
2039	desc.proto = proto;
2040
2041	if (a)
2042		strncpy(desc.a_algo, a, ALGO_LEN - 1);
2043	if (e)
2044		strncpy(desc.e_algo, e, ALGO_LEN - 1);
2045	if (c)
2046		strncpy(desc.c_algo, c, ALGO_LEN - 1);
2047	if (ae)
2048		strncpy(desc.ae_algo, ae, ALGO_LEN - 1);
2049
2050	return __write_desc(test_desc_fd, &desc);
2051}
2052
2053int proto_list[] = { IPPROTO_AH, IPPROTO_COMP, IPPROTO_ESP };
2054char *ah_list[] = {
2055	"digest_null", "hmac(md5)", "hmac(sha1)", "hmac(sha256)",
2056	"hmac(sha384)", "hmac(sha512)", "hmac(rmd160)",
2057	"xcbc(aes)", "cmac(aes)"
2058};
2059char *comp_list[] = {
2060	"deflate",
2061#if 0
2062	/* No compression backend realization */
2063	"lzs", "lzjh"
2064#endif
2065};
2066char *e_list[] = {
2067	"ecb(cipher_null)", "cbc(des)", "cbc(des3_ede)", "cbc(cast5)",
2068	"cbc(blowfish)", "cbc(aes)", "cbc(serpent)", "cbc(camellia)",
2069	"cbc(twofish)", "rfc3686(ctr(aes))"
2070};
2071char *ae_list[] = {
2072#if 0
2073	/* not implemented */
2074	"rfc4106(gcm(aes))", "rfc4309(ccm(aes))", "rfc4543(gcm(aes))",
2075	"rfc7539esp(chacha20,poly1305)"
2076#endif
2077};
2078
2079const unsigned int proto_plan = ARRAY_SIZE(ah_list) + ARRAY_SIZE(comp_list) \
2080				+ (ARRAY_SIZE(ah_list) * ARRAY_SIZE(e_list)) \
2081				+ ARRAY_SIZE(ae_list);
2082
2083static int write_proto_plan(int fd, int proto)
2084{
2085	unsigned int i;
2086
2087	switch (proto) {
2088	case IPPROTO_AH:
2089		for (i = 0; i < ARRAY_SIZE(ah_list); i++) {
2090			if (write_desc(proto, fd, ah_list[i], 0, 0, 0))
2091				return -1;
2092		}
2093		break;
2094	case IPPROTO_COMP:
2095		for (i = 0; i < ARRAY_SIZE(comp_list); i++) {
2096			if (write_desc(proto, fd, 0, 0, comp_list[i], 0))
2097				return -1;
2098		}
2099		break;
2100	case IPPROTO_ESP:
2101		for (i = 0; i < ARRAY_SIZE(ah_list); i++) {
2102			int j;
2103
2104			for (j = 0; j < ARRAY_SIZE(e_list); j++) {
2105				if (write_desc(proto, fd, ah_list[i],
2106							e_list[j], 0, 0))
2107					return -1;
2108			}
2109		}
2110		for (i = 0; i < ARRAY_SIZE(ae_list); i++) {
2111			if (write_desc(proto, fd, 0, 0, 0, ae_list[i]))
2112				return -1;
2113		}
2114		break;
2115	default:
2116		printk("BUG: Specified unknown proto %d", proto);
2117		return -1;
2118	}
2119
2120	return 0;
2121}
2122
2123/*
2124 * Some structures in xfrm uapi header differ in size between
2125 * 64-bit and 32-bit ABI:
2126 *
2127 *             32-bit UABI               |            64-bit UABI
2128 *  -------------------------------------|-------------------------------------
2129 *   sizeof(xfrm_usersa_info)     = 220  |  sizeof(xfrm_usersa_info)     = 224
2130 *   sizeof(xfrm_userpolicy_info) = 164  |  sizeof(xfrm_userpolicy_info) = 168
2131 *   sizeof(xfrm_userspi_info)    = 228  |  sizeof(xfrm_userspi_info)    = 232
2132 *   sizeof(xfrm_user_acquire)    = 276  |  sizeof(xfrm_user_acquire)    = 280
2133 *   sizeof(xfrm_user_expire)     = 224  |  sizeof(xfrm_user_expire)     = 232
2134 *   sizeof(xfrm_user_polexpire)  = 168  |  sizeof(xfrm_user_polexpire)  = 176
2135 *
 * Check the structures affected by the UABI difference.
 * Also, check translation for xfrm_set_spdinfo: it has its own attributes
 * which need to be correctly copied, but not translated.
2139 */
2140const unsigned int compat_plan = 5;
2141static int write_compat_struct_tests(int test_desc_fd)
2142{
2143	struct xfrm_desc desc = {};
2144
2145	desc.type = ALLOCATE_SPI;
2146	desc.proto = IPPROTO_AH;
2147	strncpy(desc.a_algo, ah_list[0], ALGO_LEN - 1);
2148
2149	if (__write_desc(test_desc_fd, &desc))
2150		return -1;
2151
2152	desc.type = MONITOR_ACQUIRE;
2153	if (__write_desc(test_desc_fd, &desc))
2154		return -1;
2155
2156	desc.type = EXPIRE_STATE;
2157	if (__write_desc(test_desc_fd, &desc))
2158		return -1;
2159
2160	desc.type = EXPIRE_POLICY;
2161	if (__write_desc(test_desc_fd, &desc))
2162		return -1;
2163
2164	desc.type = SPDINFO_ATTRS;
2165	if (__write_desc(test_desc_fd, &desc))
2166		return -1;
2167
2168	return 0;
2169}
2170
2171static int write_test_plan(int test_desc_fd)
2172{
2173	unsigned int i;
2174	pid_t child;
2175
2176	child = fork();
2177	if (child < 0) {
2178		pr_err("fork()");
2179		return -1;
2180	}
2181	if (child) {
2182		if (close(test_desc_fd))
2183			printk("close(): %m");
2184		return 0;
2185	}
2186
2187	if (write_compat_struct_tests(test_desc_fd))
2188		exit(KSFT_FAIL);
2189
2190	for (i = 0; i < ARRAY_SIZE(proto_list); i++) {
2191		if (write_proto_plan(test_desc_fd, proto_list[i]))
2192			exit(KSFT_FAIL);
2193	}
2194
2195	exit(KSFT_PASS);
2196}
2197
2198static int children_cleanup(void)
2199{
2200	unsigned ret = KSFT_PASS;
2201
2202	while (1) {
2203		int status;
2204		pid_t p = wait(&status);
2205
2206		if ((p < 0) && errno == ECHILD)
2207			break;
2208
2209		if (p < 0) {
2210			pr_err("wait()");
2211			return KSFT_FAIL;
2212		}
2213
2214		if (!WIFEXITED(status)) {
2215			ret = KSFT_FAIL;
2216			continue;
2217		}
2218
2219		if (WEXITSTATUS(status) == KSFT_FAIL)
2220			ret = KSFT_FAIL;
2221	}
2222
2223	return ret;
2224}
2225
2226typedef void (*print_res)(const char *, ...);
2227
2228static int check_results(void)
2229{
2230	struct test_result tr = {};
2231	struct xfrm_desc *d = &tr.desc;
2232	int ret = KSFT_PASS;
2233
2234	while (1) {
2235		ssize_t received = read(results_fd[0], &tr, sizeof(tr));
2236		print_res result;
2237
2238		if (received == 0) /* EOF */
2239			break;
2240
2241		if (received != sizeof(tr)) {
2242			pr_err("read() returned %zd", received);
2243			return KSFT_FAIL;
2244		}
2245
2246		switch (tr.res) {
2247		case KSFT_PASS:
2248			result = ksft_test_result_pass;
2249			break;
2250		case KSFT_FAIL:
2251		default:
2252			result = ksft_test_result_fail;
2253			ret = KSFT_FAIL;
2254		}
2255
2256		result(" %s: [%u, '%s', '%s', '%s', '%s', %u]\n",
2257		       desc_name[d->type], (unsigned int)d->proto, d->a_algo,
2258		       d->e_algo, d->c_algo, d->ae_algo, d->icv_len);
2259	}
2260
2261	return ret;
2262}
2263
2264int main(int argc, char **argv)
2265{
2266	long nr_process = 1;
2267	int route_sock = -1, ret = KSFT_SKIP;
2268	int test_desc_fd[2];
2269	uint32_t route_seq;
2270	unsigned int i;
2271
2272	if (argc > 2)
2273		exit_usage(argv);
2274
2275	if (argc > 1) {
2276		char *endptr;
2277
2278		errno = 0;
2279		nr_process = strtol(argv[1], &endptr, 10);
2280		if ((errno == ERANGE && (nr_process == LONG_MAX || nr_process == LONG_MIN))
2281				|| (errno != 0 && nr_process == 0)
2282				|| (endptr == argv[1]) || (*endptr != '\0')) {
2283			printk("Failed to parse [nr_process]");
2284			exit_usage(argv);
2285		}
2286
2287		if (nr_process > MAX_PROCESSES || nr_process < 1) {
2288			printk("nr_process should be between [1; %u]",
2289					MAX_PROCESSES);
2290			exit_usage(argv);
2291		}
2292	}
2293
2294	srand(time(NULL));
2295	page_size = sysconf(_SC_PAGESIZE);
2296	if (page_size < 1)
2297		ksft_exit_skip("sysconf(): %m\n");
2298
2299	if (pipe2(test_desc_fd, O_DIRECT) < 0)
2300		ksft_exit_skip("pipe(): %m\n");
2301
2302	if (pipe2(results_fd, O_DIRECT) < 0)
2303		ksft_exit_skip("pipe(): %m\n");
2304
2305	if (init_namespaces())
2306		ksft_exit_skip("Failed to create namespaces\n");
2307
2308	if (netlink_sock(&route_sock, &route_seq, NETLINK_ROUTE))
2309		ksft_exit_skip("Failed to open netlink route socket\n");
2310
2311	for (i = 0; i < nr_process; i++) {
2312		char veth[VETH_LEN];
2313
2314		snprintf(veth, VETH_LEN, VETH_FMT, i);
2315
2316		if (veth_add(route_sock, route_seq++, veth, nsfd_childa, veth, nsfd_childb)) {
2317			close(route_sock);
2318			ksft_exit_fail_msg("Failed to create veth device");
2319		}
2320
2321		if (start_child(i, veth, test_desc_fd)) {
2322			close(route_sock);
2323			ksft_exit_fail_msg("Child %u failed to start", i);
2324		}
2325	}
2326
2327	if (close(route_sock) || close(test_desc_fd[0]) || close(results_fd[1]))
2328		ksft_exit_fail_msg("close(): %m");
2329
2330	ksft_set_plan(proto_plan + compat_plan);
2331
2332	if (write_test_plan(test_desc_fd[1]))
2333		ksft_exit_fail_msg("Failed to write test plan to pipe");
2334
2335	ret = check_results();
2336
2337	if (children_cleanup() == KSFT_FAIL)
2338		exit(KSFT_FAIL);
2339
2340	exit(ret);
2341}
2342