1// SPDX-License-Identifier: GPL-2.0
2#include <alloca.h>
3#include <fcntl.h>
4#include <inttypes.h>
5#include <string.h>
6#include "../../../../../include/linux/kernel.h"
7#include "../../../../../include/linux/stringify.h"
8#include "aolib.h"
9
/*
 * Port number shared by the test helpers; presumably the port test servers
 * listen on and clients connect to (not referenced in this chunk — TODO
 * confirm against callers).
 */
const unsigned int test_server_port = 7010;
11int __test_listen_socket(int backlog, void *addr, size_t addr_sz)
12{
13	int err, sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
14	long flags;
15
16	if (sk < 0)
17		test_error("socket()");
18
19	err = setsockopt(sk, SOL_SOCKET, SO_BINDTODEVICE, veth_name,
20			 strlen(veth_name) + 1);
21	if (err < 0)
22		test_error("setsockopt(SO_BINDTODEVICE)");
23
24	if (bind(sk, (struct sockaddr *)addr, addr_sz) < 0)
25		test_error("bind()");
26
27	flags = fcntl(sk, F_GETFL);
28	if ((flags < 0) || (fcntl(sk, F_SETFL, flags | O_NONBLOCK) < 0))
29		test_error("fcntl()");
30
31	if (listen(sk, backlog))
32		test_error("listen()");
33
34	return sk;
35}
36
/*
 * test_wait_fd() - wait until @sk is readable or writable, then report any
 * pending socket error.
 * @sk:    descriptor to wait on; must be in [0, FD_SETSIZE) for select(2).
 * @sec:   timeout in seconds; 0 waits without a timeout.
 * @write: true to wait for writability, false for readability.
 *
 * @sk is also watched for exceptional conditions.
 *
 * Return: 0 on success; -ETIMEDOUT (with errno = ETIMEDOUT) when @sec
 * elapsed; -EBADF for a negative @sk and -EINVAL for @sk >= FD_SETSIZE
 * (errno set accordingly); otherwise a negative errno from select() or
 * getsockopt(), or the negated pending SO_ERROR value.
 */
int test_wait_fd(int sk, time_t sec, bool write)
{
	struct timeval tv = { .tv_sec = sec };
	struct timeval *ptv = NULL;
	fd_set fds, efds;
	int ret;
	socklen_t slen = sizeof(ret);

	/* FD_SET() on a descriptor outside [0, FD_SETSIZE) is undefined. */
	if (sk < 0) {
		errno = EBADF;
		return -EBADF;
	}
	if (sk >= FD_SETSIZE) {
		errno = EINVAL;
		return -EINVAL;
	}

	FD_ZERO(&fds);
	FD_SET(sk, &fds);
	FD_ZERO(&efds);
	FD_SET(sk, &efds);

	/* A NULL timeout makes select() wait indefinitely. */
	if (sec)
		ptv = &tv;

	errno = 0;
	if (write)
		ret = select(sk + 1, NULL, &fds, &efds, ptv);
	else
		ret = select(sk + 1, &fds, NULL, &efds, ptv);
	if (ret < 0)
		return -errno;
	if (ret == 0) {
		errno = ETIMEDOUT;
		return -ETIMEDOUT;
	}

	/* Readiness can also mean a failure (e.g. for a pending connect()). */
	if (getsockopt(sk, SOL_SOCKET, SO_ERROR, &ret, &slen))
		return -errno;
	if (ret)
		return -ret;
	return 0;
}
71
72int __test_connect_socket(int sk, const char *device,
73			  void *addr, size_t addr_sz, time_t timeout)
74{
75	long flags;
76	int err;
77
78	if (device != NULL) {
79		err = setsockopt(sk, SOL_SOCKET, SO_BINDTODEVICE, device,
80				 strlen(device) + 1);
81		if (err < 0)
82			test_error("setsockopt(SO_BINDTODEVICE, %s)", device);
83	}
84
85	if (!timeout) {
86		err = connect(sk, addr, addr_sz);
87		if (err) {
88			err = -errno;
89			goto out;
90		}
91		return 0;
92	}
93
94	flags = fcntl(sk, F_GETFL);
95	if ((flags < 0) || (fcntl(sk, F_SETFL, flags | O_NONBLOCK) < 0))
96		test_error("fcntl()");
97
98	if (connect(sk, addr, addr_sz) < 0) {
99		if (errno != EINPROGRESS) {
100			err = -errno;
101			goto out;
102		}
103		if (timeout < 0)
104			return sk;
105		err = test_wait_fd(sk, timeout, 1);
106		if (err)
107			goto out;
108	}
109	return sk;
110
111out:
112	close(sk);
113	return err;
114}
115
/*
 * __test_set_md5() - install a TCP-MD5 key on @sk via TCP_MD5SIG_EXT.
 * @sk:       socket to configure.
 * @addr:     peer address the key applies to (copied into tcpm_addr).
 * @addr_sz:  size of @addr; must not exceed sizeof(tcpm_addr).
 * @prefix:   address prefix length (TCP_MD5SIG_FLAG_PREFIX is always set).
 * @vrf:      if >= 0, interface index the key is scoped to
 *            (TCP_MD5SIG_FLAG_IFINDEX).
 * @password: NUL-terminated key material.
 *
 * Return: the setsockopt() result (0, or -1 with errno set), or -1 with
 * errno = EINVAL if @addr_sz is larger than tcpm_addr.
 */
int __test_set_md5(int sk, void *addr, size_t addr_sz, uint8_t prefix,
		   int vrf, const char *password)
{
	size_t pwd_len = strlen(password);
	size_t copy_len = pwd_len;
	struct tcp_md5sig md5sig = {};

	if (addr_sz > sizeof(md5sig.tcpm_addr)) {
		errno = EINVAL;
		return -1;
	}

	/*
	 * Report the real length so the kernel can reject oversized keys,
	 * but never copy more than tcpm_key can hold.
	 */
	if (copy_len > sizeof(md5sig.tcpm_key))
		copy_len = sizeof(md5sig.tcpm_key);
	md5sig.tcpm_keylen = pwd_len;
	memcpy(md5sig.tcpm_key, password, copy_len);

	md5sig.tcpm_flags = TCP_MD5SIG_FLAG_PREFIX;
	md5sig.tcpm_prefixlen = prefix;
	if (vrf >= 0) {
		md5sig.tcpm_flags |= TCP_MD5SIG_FLAG_IFINDEX;
		/* tcpm_ifindex is an int: don't truncate indexes above 255. */
		md5sig.tcpm_ifindex = vrf;
	}
	memcpy(&md5sig.tcpm_addr, addr, addr_sz);

	errno = 0;
	return setsockopt(sk, IPPROTO_TCP, TCP_MD5SIG_EXT,
			&md5sig, sizeof(md5sig));
}
136
137
138int test_prepare_key_sockaddr(struct tcp_ao_add *ao, const char *alg,
139		void *addr, size_t addr_sz, bool set_current, bool set_rnext,
140		uint8_t prefix, uint8_t vrf, uint8_t sndid, uint8_t rcvid,
141		uint8_t maclen, uint8_t keyflags,
142		uint8_t keylen, const char *key)
143{
144	memset(ao, 0, sizeof(struct tcp_ao_add));
145
146	ao->set_current	= !!set_current;
147	ao->set_rnext	= !!set_rnext;
148	ao->prefix	= prefix;
149	ao->sndid	= sndid;
150	ao->rcvid	= rcvid;
151	ao->maclen	= maclen;
152	ao->keyflags	= keyflags;
153	ao->keylen	= keylen;
154	ao->ifindex	= vrf;
155
156	memcpy(&ao->addr, addr, addr_sz);
157
158	if (strlen(alg) > 64)
159		return -ENOBUFS;
160	strncpy(ao->alg_name, alg, 64);
161
162	memcpy(ao->key, key,
163	       (keylen > TCP_AO_MAXKEYLEN) ? TCP_AO_MAXKEYLEN : keylen);
164	return 0;
165}
166
167static int test_get_ao_keys_nr(int sk)
168{
169	struct tcp_ao_getsockopt tmp = {};
170	socklen_t tmp_sz = sizeof(tmp);
171	int ret;
172
173	tmp.nkeys  = 1;
174	tmp.get_all = 1;
175
176	ret = getsockopt(sk, IPPROTO_TCP, TCP_AO_GET_KEYS, &tmp, &tmp_sz);
177	if (ret)
178		return -errno;
179	return (int)tmp.nkeys;
180}
181
182int test_get_one_ao(int sk, struct tcp_ao_getsockopt *out,
183		void *addr, size_t addr_sz, uint8_t prefix,
184		uint8_t sndid, uint8_t rcvid)
185{
186	struct tcp_ao_getsockopt tmp = {};
187	socklen_t tmp_sz = sizeof(tmp);
188	int ret;
189
190	memcpy(&tmp.addr, addr, addr_sz);
191	tmp.prefix = prefix;
192	tmp.sndid  = sndid;
193	tmp.rcvid  = rcvid;
194	tmp.nkeys  = 1;
195
196	ret = getsockopt(sk, IPPROTO_TCP, TCP_AO_GET_KEYS, &tmp, &tmp_sz);
197	if (ret)
198		return ret;
199	if (tmp.nkeys != 1)
200		return -E2BIG;
201	*out = tmp;
202	return 0;
203}
204
205int test_get_ao_info(int sk, struct tcp_ao_info_opt *out)
206{
207	socklen_t sz = sizeof(*out);
208
209	out->reserved = 0;
210	out->reserved2 = 0;
211	if (getsockopt(sk, IPPROTO_TCP, TCP_AO_INFO, out, &sz))
212		return -errno;
213	if (sz != sizeof(*out))
214		return -EMSGSIZE;
215	return 0;
216}
217
218int test_set_ao_info(int sk, struct tcp_ao_info_opt *in)
219{
220	socklen_t sz = sizeof(*in);
221
222	in->reserved = 0;
223	in->reserved2 = 0;
224	if (setsockopt(sk, IPPROTO_TCP, TCP_AO_INFO, in, sz))
225		return -errno;
226	return 0;
227}
228
229int test_cmp_getsockopt_setsockopt(const struct tcp_ao_add *a,
230				   const struct tcp_ao_getsockopt *b)
231{
232	bool is_kdf_aes_128_cmac = false;
233	bool is_cmac_aes = false;
234
235	if (!strcmp("cmac(aes128)", a->alg_name)) {
236		is_kdf_aes_128_cmac = (a->keylen != 16);
237		is_cmac_aes = true;
238	}
239
240#define __cmp_ao(member)						\
241do {									\
242	if (b->member != a->member) {					\
243		test_fail("getsockopt(): " __stringify(member) " %u != %u",	\
244				b->member, a->member);			\
245		return -1;						\
246	}								\
247} while(0)
248	__cmp_ao(sndid);
249	__cmp_ao(rcvid);
250	__cmp_ao(prefix);
251	__cmp_ao(keyflags);
252	__cmp_ao(ifindex);
253	if (a->maclen) {
254		__cmp_ao(maclen);
255	} else if (b->maclen != 12) {
256		test_fail("getsockopt(): expected default maclen 12, but it's %u",
257				b->maclen);
258		return -1;
259	}
260	if (!is_kdf_aes_128_cmac) {
261		__cmp_ao(keylen);
262	} else if (b->keylen != 16) {
263		test_fail("getsockopt(): expected keylen 16 for cmac(aes128), but it's %u",
264				b->keylen);
265		return -1;
266	}
267#undef __cmp_ao
268	if (!is_kdf_aes_128_cmac && memcmp(b->key, a->key, a->keylen)) {
269		test_fail("getsockopt(): returned key is different `%s' != `%s'",
270				b->key, a->key);
271		return -1;
272	}
273	if (memcmp(&b->addr, &a->addr, sizeof(b->addr))) {
274		test_fail("getsockopt(): returned address is different");
275		return -1;
276	}
277	if (!is_cmac_aes && strcmp(b->alg_name, a->alg_name)) {
278		test_fail("getsockopt(): returned algorithm %s is different than %s", b->alg_name, a->alg_name);
279		return -1;
280	}
281	if (is_cmac_aes && strcmp(b->alg_name, "cmac(aes)")) {
282		test_fail("getsockopt(): returned algorithm %s is different than cmac(aes)", b->alg_name);
283		return -1;
284	}
285	/* For a established key rotation test don't add a key with
286	 * set_current = 1, as it's likely to change by peer's request;
287	 * rather use setsockopt(TCP_AO_INFO)
288	 */
289	if (a->set_current != b->is_current) {
290		test_fail("getsockopt(): returned key is not Current_key");
291		return -1;
292	}
293	if (a->set_rnext != b->is_rnext) {
294		test_fail("getsockopt(): returned key is not RNext_key");
295		return -1;
296	}
297
298	return 0;
299}
300
301int test_cmp_getsockopt_setsockopt_ao(const struct tcp_ao_info_opt *a,
302				      const struct tcp_ao_info_opt *b)
303{
304	/* No check for ::current_key, as it may change by the peer */
305	if (a->ao_required != b->ao_required) {
306		test_fail("getsockopt(): returned ao doesn't have ao_required");
307		return -1;
308	}
309	if (a->accept_icmps != b->accept_icmps) {
310		test_fail("getsockopt(): returned ao doesn't accept ICMPs");
311		return -1;
312	}
313	if (a->set_rnext && a->rnext != b->rnext) {
314		test_fail("getsockopt(): RNext KeyID has changed");
315		return -1;
316	}
317#define __cmp_cnt(member)						\
318do {									\
319	if (b->member != a->member) {					\
320		test_fail("getsockopt(): " __stringify(member) " %llu != %llu",	\
321				b->member, a->member);			\
322		return -1;						\
323	}								\
324} while(0)
325	if (a->set_counters) {
326		__cmp_cnt(pkt_good);
327		__cmp_cnt(pkt_bad);
328		__cmp_cnt(pkt_key_not_found);
329		__cmp_cnt(pkt_ao_required);
330		__cmp_cnt(pkt_dropped_icmp);
331	}
332#undef __cmp_cnt
333	return 0;
334}
335
336int test_get_tcp_ao_counters(int sk, struct tcp_ao_counters *out)
337{
338	struct tcp_ao_getsockopt *key_dump;
339	socklen_t key_dump_sz = sizeof(*key_dump);
340	struct tcp_ao_info_opt info = {};
341	bool c1, c2, c3, c4, c5;
342	struct netstat *ns;
343	int err, nr_keys;
344
345	memset(out, 0, sizeof(*out));
346
347	/* per-netns */
348	ns = netstat_read();
349	out->netns_ao_good = netstat_get(ns, "TCPAOGood", &c1);
350	out->netns_ao_bad = netstat_get(ns, "TCPAOBad", &c2);
351	out->netns_ao_key_not_found = netstat_get(ns, "TCPAOKeyNotFound", &c3);
352	out->netns_ao_required = netstat_get(ns, "TCPAORequired", &c4);
353	out->netns_ao_dropped_icmp = netstat_get(ns, "TCPAODroppedIcmps", &c5);
354	netstat_free(ns);
355	if (c1 || c2 || c3 || c4 || c5)
356		return -EOPNOTSUPP;
357
358	err = test_get_ao_info(sk, &info);
359	if (err)
360		return err;
361
362	/* per-socket */
363	out->ao_info_pkt_good		= info.pkt_good;
364	out->ao_info_pkt_bad		= info.pkt_bad;
365	out->ao_info_pkt_key_not_found	= info.pkt_key_not_found;
366	out->ao_info_pkt_ao_required	= info.pkt_ao_required;
367	out->ao_info_pkt_dropped_icmp	= info.pkt_dropped_icmp;
368
369	/* per-key */
370	nr_keys = test_get_ao_keys_nr(sk);
371	if (nr_keys < 0)
372		return nr_keys;
373	if (nr_keys == 0)
374		test_error("test_get_ao_keys_nr() == 0");
375	out->nr_keys = (size_t)nr_keys;
376	key_dump = calloc(nr_keys, key_dump_sz);
377	if (!key_dump)
378		return -errno;
379
380	key_dump[0].nkeys = nr_keys;
381	key_dump[0].get_all = 1;
382	key_dump[0].get_all = 1;
383	err = getsockopt(sk, IPPROTO_TCP, TCP_AO_GET_KEYS,
384			 key_dump, &key_dump_sz);
385	if (err) {
386		free(key_dump);
387		return -errno;
388	}
389
390	out->key_cnts = calloc(nr_keys, sizeof(out->key_cnts[0]));
391	if (!out->key_cnts) {
392		free(key_dump);
393		return -errno;
394	}
395
396	while (nr_keys--) {
397		out->key_cnts[nr_keys].sndid = key_dump[nr_keys].sndid;
398		out->key_cnts[nr_keys].rcvid = key_dump[nr_keys].rcvid;
399		out->key_cnts[nr_keys].pkt_good = key_dump[nr_keys].pkt_good;
400		out->key_cnts[nr_keys].pkt_bad = key_dump[nr_keys].pkt_bad;
401	}
402	free(key_dump);
403
404	return 0;
405}
406
407int __test_tcp_ao_counters_cmp(const char *tst_name,
408			       struct tcp_ao_counters *before,
409			       struct tcp_ao_counters *after,
410			       test_cnt expected)
411{
412#define __cmp_ao(cnt, expecting_inc)					\
413do {									\
414	if (before->cnt > after->cnt) {					\
415		test_fail("%s: Decreased counter " __stringify(cnt) " %" PRIu64 " > %" PRIu64, \
416			  tst_name ?: "", before->cnt, after->cnt);		\
417		return -1;						\
418	}								\
419	if ((before->cnt != after->cnt) != (expecting_inc)) {		\
420		test_fail("%s: Counter " __stringify(cnt) " was %sexpected to increase %" PRIu64 " => %" PRIu64, \
421			  tst_name ?: "", (expecting_inc) ? "" : "not ",	\
422			  before->cnt, after->cnt);			\
423		return -1;						\
424	}								\
425} while(0)
426
427	errno = 0;
428	/* per-netns */
429	__cmp_ao(netns_ao_good, !!(expected & TEST_CNT_NS_GOOD));
430	__cmp_ao(netns_ao_bad, !!(expected & TEST_CNT_NS_BAD));
431	__cmp_ao(netns_ao_key_not_found,
432		 !!(expected & TEST_CNT_NS_KEY_NOT_FOUND));
433	__cmp_ao(netns_ao_required, !!(expected & TEST_CNT_NS_AO_REQUIRED));
434	__cmp_ao(netns_ao_dropped_icmp,
435		 !!(expected & TEST_CNT_NS_DROPPED_ICMP));
436	/* per-socket */
437	__cmp_ao(ao_info_pkt_good, !!(expected & TEST_CNT_SOCK_GOOD));
438	__cmp_ao(ao_info_pkt_bad, !!(expected & TEST_CNT_SOCK_BAD));
439	__cmp_ao(ao_info_pkt_key_not_found,
440		 !!(expected & TEST_CNT_SOCK_KEY_NOT_FOUND));
441	__cmp_ao(ao_info_pkt_ao_required, !!(expected & TEST_CNT_SOCK_AO_REQUIRED));
442	__cmp_ao(ao_info_pkt_dropped_icmp,
443		 !!(expected & TEST_CNT_SOCK_DROPPED_ICMP));
444	return 0;
445#undef __cmp_ao
446}
447
448int test_tcp_ao_key_counters_cmp(const char *tst_name,
449				 struct tcp_ao_counters *before,
450				 struct tcp_ao_counters *after,
451				 test_cnt expected,
452				 int sndid, int rcvid)
453{
454	size_t i;
455#define __cmp_ao(i, cnt, expecting_inc)					\
456do {									\
457	if (before->key_cnts[i].cnt > after->key_cnts[i].cnt) {		\
458		test_fail("%s: Decreased counter " __stringify(cnt) " %" PRIu64 " > %" PRIu64 " for key %u:%u", \
459			  tst_name ?: "", before->key_cnts[i].cnt,	\
460			  after->key_cnts[i].cnt,			\
461			  before->key_cnts[i].sndid,			\
462			  before->key_cnts[i].rcvid);			\
463		return -1;						\
464	}								\
465	if ((before->key_cnts[i].cnt != after->key_cnts[i].cnt) != (expecting_inc)) {		\
466		test_fail("%s: Counter " __stringify(cnt) " was %sexpected to increase %" PRIu64 " => %" PRIu64 " for key %u:%u", \
467			  tst_name ?: "", (expecting_inc) ? "" : "not ",\
468			  before->key_cnts[i].cnt,			\
469			  after->key_cnts[i].cnt,			\
470			  before->key_cnts[i].sndid,			\
471			  before->key_cnts[i].rcvid);			\
472		return -1;						\
473	}								\
474} while(0)
475
476	if (before->nr_keys != after->nr_keys) {
477		test_fail("%s: Keys changed on the socket %zu != %zu",
478			  tst_name, before->nr_keys, after->nr_keys);
479		return -1;
480	}
481
482	/* per-key */
483	i = before->nr_keys;
484	while (i--) {
485		if (sndid >= 0 && before->key_cnts[i].sndid != sndid)
486			continue;
487		if (rcvid >= 0 && before->key_cnts[i].rcvid != rcvid)
488			continue;
489		__cmp_ao(i, pkt_good, !!(expected & TEST_CNT_KEY_GOOD));
490		__cmp_ao(i, pkt_bad, !!(expected & TEST_CNT_KEY_BAD));
491	}
492	return 0;
493#undef __cmp_ao
494}
495
496void test_tcp_ao_counters_free(struct tcp_ao_counters *cnts)
497{
498	free(cnts->key_cnts);
499}
500
#define TEST_BUF_SIZE 4096
/*
 * test_server_run() - echo everything received on @sk back to the peer.
 * @sk:          connected socket.
 * @quota:       stop once at least this many bytes were echoed; 0 means
 *               keep going until the peer closes or send() returns 0.
 * @timeout_sec: per-wait timeout passed to test_wait_fd().
 *
 * Data is relayed in chunks of up to TEST_BUF_SIZE bytes.  recv() errors
 * and short sends are reported through test_error().
 *
 * Return: number of bytes echoed, or a negative errno from test_wait_fd().
 */
ssize_t test_server_run(int sk, ssize_t quota, time_t timeout_sec)
{
	ssize_t echoed = 0;

	for (;;) {
		char chunk[TEST_BUF_SIZE];
		ssize_t got, put;
		int err;

		err = test_wait_fd(sk, timeout_sec, 0);
		if (err)
			return err;

		got = recv(sk, chunk, sizeof(chunk), 0);
		if (got < 0)
			test_error("recv(): %zd", got);
		/* Orderly shutdown by the peer. */
		if (got == 0)
			break;

		err = test_wait_fd(sk, timeout_sec, 1);
		if (err)
			return err;

		put = send(sk, chunk, got, 0);
		if (put == 0)
			break;
		if (put != got)
			test_error("send()");
		echoed += got;

		if (quota && echoed >= quota)
			break;
	}

	return echoed;
}
536
/*
 * test_client_loop() - send @buf in chunks and verify that each is echoed.
 * @sk:          connected socket; TCP_NODELAY is enabled on it.
 * @buf:         data to send.
 * @buf_sz:      number of bytes in @buf.
 * @msg_len:     maximum chunk size per send().
 * @timeout_sec: per-wait timeout passed to test_wait_fd().
 *
 * Each chunk is sent, then read back until as many bytes as were sent have
 * arrived, and compared with the original.  send() failures or short
 * sends and over-long echoes are reported through test_error().
 *
 * Return: number of bytes fully verified (less than @buf_sz if send()
 * returned 0 or recv() hit EOF/error), a negative errno from
 * test_wait_fd(), or -1 if an echoed chunk differed.
 */
ssize_t test_client_loop(int sk, char *buf, size_t buf_sz,
			 const size_t msg_len, time_t timeout_sec)
{
	/* A zero-length VLA is undefined; keep at least one byte. */
	char msg[msg_len ? msg_len : 1];
	int nodelay = 1;
	size_t i;

	if (setsockopt(sk, IPPROTO_TCP, TCP_NODELAY, &nodelay, sizeof(nodelay)))
		test_error("setsockopt(TCP_NODELAY)");

	for (i = 0; i < buf_sz; i += min(msg_len, buf_sz - i)) {
		size_t bytes = min(msg_len, buf_sz - i);
		ssize_t sent;
		int ret;

		ret = test_wait_fd(sk, timeout_sec, 1);
		if (ret)
			return ret;

		sent = send(sk, buf + i, bytes, 0);
		if (sent == 0)
			break;
		/* send() reports errors as -1: never let that pass as a size. */
		if (sent < 0 || (size_t)sent != bytes)
			test_error("send()");

		bytes = 0;
		do {
			ssize_t got;

			ret = test_wait_fd(sk, timeout_sec, 0);
			if (ret)
				return ret;

			got = recv(sk, msg + bytes, sizeof(msg) - bytes, 0);
			if (got <= 0)
				return i;
			bytes += got;
		} while (bytes < (size_t)sent);
		if (bytes > (size_t)sent)
			test_error("recv(): %zu > %zd", bytes, sent);
		if (memcmp(buf + i, msg, bytes) != 0) {
			test_fail("received message differs");
			return -1;
		}
	}
	return i;
}
583
/*
 * test_client_verify() - send @nr random messages of @msg_len bytes over
 * @sk and check that they are all echoed back intact.
 * @sk:          connected socket.
 * @msg_len:     size of each message.
 * @nr:          number of messages.
 * @timeout_sec: per-wait timeout passed down to test_client_loop().
 *
 * The payload (@msg_len * @nr bytes) is allocated on the stack with
 * alloca() and filled by randomize_buffer().
 *
 * Return: 0 if everything was echoed, a negative value propagated from
 * test_client_loop(), or -1 if fewer bytes than sent were verified.
 */
int test_client_verify(int sk, const size_t msg_len, const size_t nr,
		       time_t timeout_sec)
{
	const size_t total = msg_len * nr;
	char *payload = alloca(total);
	ssize_t done;

	randomize_buffer(payload, total);

	done = test_client_loop(sk, payload, total, msg_len, timeout_sec);
	if (done < 0)
		return (int)done;
	if ((size_t)done != total)
		return -1;
	return 0;
}
597