1// SPDX-License-Identifier: GPL-2.0
2/* Author: Dmitry Safonov <dima@arista.com> */
3#include <inttypes.h>
4#include "../../../../include/linux/kernel.h"
5#include "aolib.h"
6
/* Traffic per exchange: nr_packets messages of msg_len bytes each */
const size_t nr_packets = 20;
const size_t msg_len = 100;
/* Total payload of one exchange, in bytes */
const size_t quota = nr_packets * msg_len;
/*
 * Peer address for keys that must NOT match the connection: used by
 * key_collection_socket() and try_unmatched_keys(). It is not assigned in
 * this part of the file — NOTE(review): presumably set up before tests run.
 */
union tcp_addr wrong_addr;
/* Password of the additional key that prepare_sk() installs */
#define SECOND_PASSWORD	"at all times sincere friends of freedom have been rare"
/*
 * True when the local variable `inj` (a fault_t) equals FAULT_<type>;
 * `inj` must be in scope wherever this macro is used.
 */
#define fault(type)	(inj == FAULT_ ## type)
13
14static const int test_vrf_ifindex = 200;
15static const uint8_t test_vrf_tabid = 42;
16static void setup_vrfs(void)
17{
18	int err;
19
20	if (!kernel_config_has(KCONFIG_NET_VRF))
21		return;
22
23	err = add_vrf("ksft-vrf", test_vrf_tabid, test_vrf_ifindex, -1);
24	if (err)
25		test_error("Failed to add a VRF: %d", err);
26
27	err = link_set_up("ksft-vrf");
28	if (err)
29		test_error("Failed to bring up a VRF");
30
31	err = ip_route_add_vrf(veth_name, TEST_FAMILY,
32			       this_ip_addr, this_ip_dest, test_vrf_tabid);
33	if (err)
34		test_error("Failed to add a route to VRF");
35}
36
37
38static int prepare_sk(union tcp_addr *addr, uint8_t sndid, uint8_t rcvid)
39{
40	int sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
41
42	if (sk < 0)
43		test_error("socket()");
44
45	if (test_add_key(sk, DEFAULT_TEST_PASSWORD, this_ip_dest,
46			 DEFAULT_TEST_PREFIX, 100, 100))
47		test_error("test_add_key()");
48
49	if (addr && test_add_key(sk, SECOND_PASSWORD, *addr,
50				 DEFAULT_TEST_PREFIX, sndid, rcvid))
51		test_error("test_add_key()");
52
53	return sk;
54}
55
/*
 * Same as prepare_sk(), then move the socket into the listening state
 * (backlog 10). Returns the listening socket.
 */
static int prepare_lsk(union tcp_addr *addr, uint8_t sndid, uint8_t rcvid)
{
	const int backlog = 10;
	int lsk = prepare_sk(addr, sndid, rcvid);

	if (listen(lsk, backlog) != 0)
		test_error("listen()");
	return lsk;
}
65
/*
 * Delete key @sndid:@rcvid (for this_ip_dest/DEFAULT_TEST_PREFIX) with
 * TCP_AO_DEL_KEY. In the same request, optionally move the current
 * (@current_key) and/or rnext (@rnext_key) key; a negative id leaves that
 * setting alone.
 *
 * Unless @async, read the socket back afterwards to confirm that the key
 * is gone and that the requested current/rnext took effect.
 *
 * Returns:
 *   0                 on success;
 *   -errno            if setsockopt() failed;
 *   -EEXIST           if the key can still be found;
 *   -ENOTRECOVERABLE  if current/rnext differ from the requested ones.
 */
static int test_del_key(int sk, uint8_t sndid, uint8_t rcvid, bool async,
			int current_key, int rnext_key)
{
	struct tcp_ao_info_opt ao_info = {};
	struct tcp_ao_getsockopt key = {};
	struct tcp_ao_del del = {};
	sockaddr_af sockaddr;
	int err;

	tcp_addr_to_sockaddr_in(&del.addr, &this_ip_dest, 0);
	del.prefix = DEFAULT_TEST_PREFIX;
	del.sndid = sndid;
	del.rcvid = rcvid;

	if (current_key >= 0) {
		del.set_current = 1;
		del.current_key = (uint8_t)current_key;
	}
	if (rnext_key >= 0) {
		del.set_rnext = 1;
		del.rnext = (uint8_t)rnext_key;
	}

	err = setsockopt(sk, IPPROTO_TCP, TCP_AO_DEL_KEY, &del, sizeof(del));
	if (err < 0)
		return -errno;

	/* The caller asked not to verify the result */
	if (async)
		return 0;

	tcp_addr_to_sockaddr_in(&sockaddr, &this_ip_dest, 0);
	err = test_get_one_ao(sk, &key, &sockaddr, sizeof(sockaddr),
			      DEFAULT_TEST_PREFIX, sndid, rcvid);
	/* The deleted key must no longer be found on the socket */
	if (!err)
		return -EEXIST;
	/*
	 * NOTE(review): -E2BIG is accepted here as "key not found";
	 * presumably that is what test_get_one_ao() returns when no key
	 * matched — confirm in aolib.
	 */
	if (err != -E2BIG)
		test_error("getsockopt()");
	if (current_key < 0 && rnext_key < 0)
		return 0;
	if (test_get_ao_info(sk, &ao_info))
		test_error("getsockopt(TCP_AO_INFO) failed");
	if (current_key >= 0 && ao_info.current_key != (uint8_t)current_key)
		return -ENOTRECOVERABLE;
	if (rnext_key >= 0 && ao_info.rnext != (uint8_t)rnext_key)
		return -ENOTRECOVERABLE;
	return 0;
}
113
/*
 * Run test_del_key() and report the outcome against the expected fault @inj:
 *   FAULT_BUSY     - ok only if the deletion failed with -EBUSY;
 *   FAULT_CURRNEXT - ok only if the deletion failed with -EINVAL;
 *   FAULT_FIXME    - any failure is reported as xfail, success as ok;
 *   anything else  - ok only if the deletion succeeded.
 */
static void try_delete_key(char *tst_name, int sk, uint8_t sndid, uint8_t rcvid,
			   bool async, int current_key, int rnext_key,
			   fault_t inj)
{
	int err;

	err = test_del_key(sk, sndid, rcvid, async, current_key, rnext_key);
	/* The expected refusal happened */
	if ((err == -EBUSY && fault(BUSY)) || (err == -EINVAL && fault(CURRNEXT))) {
		test_ok("%s: key deletion was prevented", tst_name);
		return;
	}
	/* Known-broken case: any failure is an expected failure */
	if (err && fault(FIXME)) {
		test_xfail("%s: failed to delete the key %u:%u %d",
			   tst_name, sndid, rcvid, err);
		return;
	}
	if (!err) {
		/* Deleted although a refusal was expected */
		if (fault(BUSY) || fault(CURRNEXT)) {
			test_fail("%s: the key was deleted %u:%u %d", tst_name,
				  sndid, rcvid, err);
		} else {
			test_ok("%s: the key was deleted", tst_name);
		}
		return;
	}
	/* Unexpected error (or the wrong error for the expected fault) */
	test_fail("%s: can't delete the key %u:%u %d", tst_name, sndid, rcvid, err);
}
141
142static int test_set_key(int sk, int current_keyid, int rnext_keyid)
143{
144	struct tcp_ao_info_opt ao_info = {};
145	int err;
146
147	if (current_keyid >= 0) {
148		ao_info.set_current = 1;
149		ao_info.current_key = (uint8_t)current_keyid;
150	}
151	if (rnext_keyid >= 0) {
152		ao_info.set_rnext = 1;
153		ao_info.rnext = (uint8_t)rnext_keyid;
154	}
155
156	err = test_set_ao_info(sk, &ao_info);
157	if (err)
158		return err;
159	if (test_get_ao_info(sk, &ao_info))
160		test_error("getsockopt(TCP_AO_INFO) failed");
161	if (current_keyid >= 0 && ao_info.current_key != (uint8_t)current_keyid)
162		return -ENOTRECOVERABLE;
163	if (rnext_keyid >= 0 && ao_info.rnext != (uint8_t)rnext_keyid)
164		return -ENOTRECOVERABLE;
165	return 0;
166}
167
168static int test_add_current_rnext_key(int sk, const char *key, uint8_t keyflags,
169				      union tcp_addr in_addr, uint8_t prefix,
170				      bool set_current, bool set_rnext,
171				      uint8_t sndid, uint8_t rcvid)
172{
173	struct tcp_ao_add tmp = {};
174	int err;
175
176	err = test_prepare_key(&tmp, DEFAULT_TEST_ALGO, in_addr,
177			       set_current, set_rnext,
178			       prefix, 0, sndid, rcvid, 0, keyflags,
179			       strlen(key), key);
180	if (err)
181		return err;
182
183
184	err = setsockopt(sk, IPPROTO_TCP, TCP_AO_ADD_KEY, &tmp, sizeof(tmp));
185	if (err < 0)
186		return -errno;
187
188	return test_verify_socket_key(sk, &tmp);
189}
190
191static int __try_add_current_rnext_key(int sk, const char *key, uint8_t keyflags,
192				       union tcp_addr in_addr, uint8_t prefix,
193				       bool set_current, bool set_rnext,
194				       uint8_t sndid, uint8_t rcvid)
195{
196	struct tcp_ao_info_opt ao_info = {};
197	int err;
198
199	err = test_add_current_rnext_key(sk, key, keyflags, in_addr, prefix,
200					 set_current, set_rnext, sndid, rcvid);
201	if (err)
202		return err;
203
204	if (test_get_ao_info(sk, &ao_info))
205		test_error("getsockopt(TCP_AO_INFO) failed");
206	if (set_current && ao_info.current_key != sndid)
207		return -ENOTRECOVERABLE;
208	if (set_rnext && ao_info.rnext != rcvid)
209		return -ENOTRECOVERABLE;
210	return 0;
211}
212
213static void try_add_current_rnext_key(char *tst_name, int sk, const char *key,
214				     uint8_t keyflags,
215				     union tcp_addr in_addr, uint8_t prefix,
216				     bool set_current, bool set_rnext,
217				     uint8_t sndid, uint8_t rcvid, fault_t inj)
218{
219	int err;
220
221	err = __try_add_current_rnext_key(sk, key, keyflags, in_addr, prefix,
222					  set_current, set_rnext, sndid, rcvid);
223	if (!err && !fault(CURRNEXT)) {
224		test_ok("%s", tst_name);
225		return;
226	}
227	if (err == -EINVAL && fault(CURRNEXT)) {
228		test_ok("%s", tst_name);
229		return;
230	}
231	test_fail("%s", tst_name);
232}
233
/*
 * Key deletion/rotation checks on sockets that were never connected or put
 * into the listening state. Each prepare_sk() call leaves the socket with
 * keys 100:100 (default password) and 200:200.
 */
static void check_closed_socket(void)
{
	int sk;

	/* Keys that are not current/rnext can be deleted, including the last one */
	sk = prepare_sk(&this_ip_dest, 200, 200);
	try_delete_key("closed socket, delete a key", sk, 200, 200, 0, -1, -1, 0);
	try_delete_key("closed socket, delete all keys", sk, 100, 100, 0, -1, -1, 0);
	close(sk);

	/* Keys selected as current (100) or rnext (200) must not be deleted */
	sk = prepare_sk(&this_ip_dest, 200, 200);
	if (test_set_key(sk, 100, 200))
		test_error("failed to set current/rnext keys");
	try_delete_key("closed socket, delete current key", sk, 100, 100, 0, -1, -1, FAULT_BUSY);
	try_delete_key("closed socket, delete rnext key", sk, 200, 200, 0, -1, -1, FAULT_BUSY);
	close(sk);

	/*
	 * Delete keys while moving current/rnext in the same request. After
	 * the two force-deletes, 200:200 is both current and rnext, so
	 * deleting it must be refused.
	 */
	sk = prepare_sk(&this_ip_dest, 200, 200);
	if (test_add_key(sk, "Glory to heros!", this_ip_dest,
			 DEFAULT_TEST_PREFIX, 10, 11))
		test_error("test_add_key()");
	if (test_add_key(sk, "Glory to Ukraine!", this_ip_dest,
			 DEFAULT_TEST_PREFIX, 12, 13))
		test_error("test_add_key()");
	try_delete_key("closed socket, delete a key + set current/rnext", sk, 100, 100, 0, 10, 13, 0);
	try_delete_key("closed socket, force-delete current key", sk, 10, 11, 0, 200, -1, 0);
	try_delete_key("closed socket, force-delete rnext key", sk, 12, 13, 0, -1, 200, 0);
	try_delete_key("closed socket, delete current+rnext key", sk, 200, 200, 0, -1, -1, FAULT_BUSY);
	close(sk);

	/* Adding a key that becomes current or rnext is allowed here */
	sk = prepare_sk(&this_ip_dest, 200, 200);
	if (test_set_key(sk, 100, 200))
		test_error("failed to set current/rnext keys");
	try_add_current_rnext_key("closed socket, add + change current key",
				  sk, "Laaaa! Lalala-la-la-lalala...", 0,
				  this_ip_dest, DEFAULT_TEST_PREFIX,
				  true, false, 10, 20, 0);
	try_add_current_rnext_key("closed socket, add + change rnext key",
				  sk, "Laaaa! Lalala-la-la-lalala...", 0,
				  this_ip_dest, DEFAULT_TEST_PREFIX,
				  false, true, 20, 10, 0);
	close(sk);
}
276
277static void assert_no_current_rnext(const char *tst_msg, int sk)
278{
279	struct tcp_ao_info_opt ao_info = {};
280
281	if (test_get_ao_info(sk, &ao_info))
282		test_error("getsockopt(TCP_AO_INFO) failed");
283
284	errno = 0;
285	if (ao_info.set_current || ao_info.set_rnext) {
286		test_xfail("%s: the socket has current/rnext keys: %d:%d",
287			   tst_msg,
288			   (ao_info.set_current) ? ao_info.current_key : -1,
289			   (ao_info.set_rnext) ? ao_info.rnext : -1);
290	} else {
291		test_ok("%s: the socket has no current/rnext keys", tst_msg);
292	}
293}
294
295static void assert_no_tcp_repair(void)
296{
297	struct tcp_ao_repair ao_img = {};
298	socklen_t len = sizeof(ao_img);
299	int sk, err;
300
301	sk = prepare_sk(&this_ip_dest, 200, 200);
302	test_enable_repair(sk);
303	if (listen(sk, 10))
304		test_error("listen()");
305	errno = 0;
306	err = getsockopt(sk, SOL_TCP, TCP_AO_REPAIR, &ao_img, &len);
307	if (err && errno == EPERM)
308		test_ok("listen socket, getsockopt(TCP_AO_REPAIR) is restricted");
309	else
310		test_fail("listen socket, getsockopt(TCP_AO_REPAIR) works");
311	errno = 0;
312	err = setsockopt(sk, SOL_TCP, TCP_AO_REPAIR, &ao_img, sizeof(ao_img));
313	if (err && errno == EPERM)
314		test_ok("listen socket, setsockopt(TCP_AO_REPAIR) is restricted");
315	else
316		test_fail("listen socket, setsockopt(TCP_AO_REPAIR) works");
317	close(sk);
318}
319
/*
 * Key deletion/rotation checks on listening sockets. Each prepare_lsk()
 * or prepare_sk() call leaves the socket with keys 100:100 (default
 * password) and 200:200.
 */
static void check_listen_socket(void)
{
	int sk, err;

	/* Deleting keys on a listener works, including the last one */
	sk = prepare_lsk(&this_ip_dest, 200, 200);
	try_delete_key("listen socket, delete a key", sk, 200, 200, 0, -1, -1, 0);
	try_delete_key("listen socket, delete all keys", sk, 100, 100, 0, -1, -1, 0);
	close(sk);

	/* A listener must refuse current/rnext selection with -EINVAL */
	sk = prepare_lsk(&this_ip_dest, 200, 200);
	err = test_set_key(sk, 100, -1);
	if (err == -EINVAL)
		test_ok("listen socket, setting current key not allowed");
	else
		test_fail("listen socket, set current key");
	err = test_set_key(sk, -1, 200);
	if (err == -EINVAL)
		test_ok("listen socket, setting rnext key not allowed");
	else
		test_fail("listen socket, set rnext key");
	close(sk);

	/*
	 * current/rnext selected before listen(): expected to be dropped
	 * (xfail otherwise); deleting those keys is a known-broken case
	 * (FAULT_FIXME).
	 */
	sk = prepare_sk(&this_ip_dest, 200, 200);
	if (test_set_key(sk, 100, 200))
		test_error("failed to set current/rnext keys");
	if (listen(sk, 10))
		test_error("listen()");
	assert_no_current_rnext("listen() after current/rnext keys set", sk);
	try_delete_key("listen socket, delete current key from before listen()", sk, 100, 100, 0, -1, -1, FAULT_FIXME);
	try_delete_key("listen socket, delete rnext key from before listen()", sk, 200, 200, 0, -1, -1, FAULT_FIXME);
	close(sk);

	assert_no_tcp_repair();

	/*
	 * Deletion combined with current/rnext selection must be refused
	 * with -EINVAL on a listener; a plain deletion still works.
	 */
	sk = prepare_lsk(&this_ip_dest, 200, 200);
	if (test_add_key(sk, "Glory to heros!", this_ip_dest,
			 DEFAULT_TEST_PREFIX, 10, 11))
		test_error("test_add_key()");
	if (test_add_key(sk, "Glory to Ukraine!", this_ip_dest,
			 DEFAULT_TEST_PREFIX, 12, 13))
		test_error("test_add_key()");
	try_delete_key("listen socket, delete a key + set current/rnext", sk,
		       100, 100, 0, 10, 13, FAULT_CURRNEXT);
	try_delete_key("listen socket, force-delete current key", sk,
		       10, 11, 0, 200, -1, FAULT_CURRNEXT);
	try_delete_key("listen socket, force-delete rnext key", sk,
		       12, 13, 0, -1, 200, FAULT_CURRNEXT);
	try_delete_key("listen socket, delete a key", sk,
		       200, 200, 0, -1, -1, 0);
	close(sk);

	/* Adding a key that also selects current/rnext must fail with -EINVAL */
	sk = prepare_lsk(&this_ip_dest, 200, 200);
	try_add_current_rnext_key("listen socket, add + change current key",
				  sk, "Laaaa! Lalala-la-la-lalala...", 0,
				  this_ip_dest, DEFAULT_TEST_PREFIX,
				  true, false, 10, 20, FAULT_CURRNEXT);
	try_add_current_rnext_key("listen socket, add + change rnext key",
				  sk, "Laaaa! Lalala-la-la-lalala...", 0,
				  this_ip_dest, DEFAULT_TEST_PREFIX,
				  false, true, 20, 10, FAULT_CURRNEXT);
	close(sk);
}
382
static const char *fips_fpath = "/proc/sys/crypto/fips_enabled";
/*
 * Return whether @fips_fpath reports a non-zero value. The answer is
 * cached after the first call. A missing file (ENOENT) counts as "not
 * enabled"; other access or read problems go to test_error().
 */
static bool is_fips_enabled(void)
{
	static int fips_checked = -1;	/* -1: not yet read, else 0/1 */
	FILE *f;
	int val;

	if (fips_checked < 0) {
		if (access(fips_fpath, R_OK) == 0) {
			f = fopen(fips_fpath, "r");
			if (!f)
				test_error("Can't open %s", fips_fpath);
			if (fscanf(f, "%d", &val) != 1)
				test_error("Can't read from %s", fips_fpath);
			fclose(f);
			fips_checked = (val != 0);
		} else {
			if (errno != ENOENT)
				test_error("Can't open %s", fips_fpath);
			fips_checked = 0;
		}
	}
	return fips_checked != 0;
}
407
/* One key of the collection, as installed by key_collection_socket() */
struct test_key {
	/* Key material; only the first @len bytes are used */
	char password[TCP_AO_MAXKEYLEN];
	/* Algorithm name, one of test_algos[] */
	const char *alg;
	unsigned int len;
	/* sndid on the client / on the server (the peer uses it as rcvid) */
	uint8_t client_keyid;
	uint8_t server_keyid;
	uint8_t maclen;
	/*
	 * matches_client: the server installs the key for this_ip_dest
	 *                 (otherwise for wrong_addr);
	 * matches_server: the same, for the client side;
	 * matches_vrf:    0 binds the key to test_vrf_ifindex;
	 * is_current/is_rnext: the client adds it as current/rnext key;
	 * used_on_*_tx:   that side is expected to send with this key, so
	 *                 the peer expects good rx counters for it;
	 * skip_counters_checks: verify_counters() ignores this key.
	 */
	uint8_t matches_client		: 1,
		matches_server		: 1,
		matches_vrf		: 1,
		is_current		: 1,
		is_rnext		: 1,
		used_on_server_tx	: 1,
		used_on_client_tx	: 1,
		skip_counters_checks	: 1;
};

/* Array of @nr_keys keys shared by the client and server test flows */
struct key_collection {
	unsigned int nr_keys;
	struct test_key *keys;
};

static struct key_collection collection;
431
/* Upper bound of the randomized maclen (see init_key_in_collection()) */
#define TEST_MAX_MACLEN		16
/*
 * Algorithms to create keys with. The trailing TEST_NON_FIPS_ALGOS
 * entries must stay last: they are excluded when FIPS mode is enabled.
 */
const char *test_algos[] = {
	"cmac(aes128)",
	"hmac(sha1)", "hmac(sha512)", "hmac(sha384)", "hmac(sha256)",
	"hmac(sha224)", "hmac(sha3-512)",
	/* only if !CONFIG_FIPS */
#define TEST_NON_FIPS_ALGOS	2
	"hmac(rmd160)", "hmac(md5)"
};
/* Non-randomized maclens, indexed by the low MACLEN_SHIFT bits of a key index */
const unsigned int test_maclens[] = { 1, 4, 12, 16 };
/* Widths (in bits) of the key-index fields that select maclen and algorithm */
#define MACLEN_SHIFT		2
#define ALGOS_SHIFT		4
444
/*
 * Return a mask of @shift consecutive set bits, with its lowest bit at
 * bit position @prev_shift.
 */
static unsigned int make_mask(unsigned int shift, unsigned int prev_shift)
{
	return (unsigned int)(BIT(shift) - 1) << prev_shift;
}
451
/*
 * Fill collection.keys[@index] as follows:
 *  - keyids, derived from @index;
 *  - all "matches" flags set;
 *  - a random-length random password;
 *  - a maclen and an algorithm: from rand() if @randomized, otherwise
 *    from bits of @index.
 * is_current, is_rnext, used_on_* and skip_counters_checks are not
 * touched here.
 */
static void init_key_in_collection(unsigned int index, bool randomized)
{
	struct test_key *key = &collection.keys[index];
	unsigned int algos_nr, algos_index;

	/* Same for randomized and non-randomized test flows */
	key->client_keyid = index;
	key->server_keyid = 127 + index;
	key->matches_client = 1;
	key->matches_server = 1;
	key->matches_vrf = 1;
	/* not really even random, but good enough for a test */
	/* len ends up in [TEST_TCP_AO_MINKEYLEN, TCP_AO_MAXKEYLEN - 1] */
	key->len = rand() % (TCP_AO_MAXKEYLEN - TEST_TCP_AO_MINKEYLEN);
	key->len += TEST_TCP_AO_MINKEYLEN;
	randomize_buffer(key->password, key->len);

	if (randomized) {
		/* maclen in [1, TEST_MAX_MACLEN] */
		key->maclen = (rand() % TEST_MAX_MACLEN) + 1;
		algos_index = rand();
	} else {
		unsigned int shift = MACLEN_SHIFT;

		/*
		 * The low MACLEN_SHIFT bits of @index pick the maclen, and the
		 * next ALGOS_SHIFT bits pick the algorithm.
		 * NOTE(review): algos_index is not shifted down by MACLEN_SHIFT,
		 * so it is always a multiple of 4 before the modulo below —
		 * confirm that is intended.
		 */
		key->maclen = test_maclens[index & make_mask(shift, 0)];
		algos_index = index & make_mask(ALGOS_SHIFT, shift);
	}
	/* In FIPS mode, never pick the trailing non-FIPS algorithms */
	algos_nr = ARRAY_SIZE(test_algos);
	if (is_fips_enabled())
		algos_nr -= TEST_NON_FIPS_ALGOS;
	key->alg = test_algos[algos_index % algos_nr];
}
482
483static int init_default_key_collection(unsigned int nr_keys, bool randomized)
484{
485	size_t key_sz = sizeof(collection.keys[0]);
486
487	if (!nr_keys) {
488		free(collection.keys);
489		collection.keys = NULL;
490		return 0;
491	}
492
493	/*
494	 * All keys have uniq sndid/rcvid and sndid != rcvid in order to
495	 * check for any bugs/issues for different keyids, visible to both
496	 * peers. Keyid == 254 is unused.
497	 */
498	if (nr_keys > 127)
499		test_error("Test requires too many keys, correct the source");
500
501	collection.keys = reallocarray(collection.keys, nr_keys, key_sz);
502	if (!collection.keys)
503		return -ENOMEM;
504
505	memset(collection.keys, 0, nr_keys * key_sz);
506	collection.nr_keys = nr_keys;
507	while (nr_keys--)
508		init_key_in_collection(nr_keys, randomized);
509
510	return 0;
511}
512
513static void test_key_error(const char *msg, struct test_key *key)
514{
515	test_error("%s: key: { %s, %u:%u, %u, %u:%u:%u:%u:%u (%u)}",
516		   msg, key->alg, key->client_keyid, key->server_keyid,
517		   key->maclen, key->matches_client, key->matches_server,
518		   key->matches_vrf, key->is_current, key->is_rnext, key->len);
519}
520
521static int test_add_key_cr(int sk, const char *pwd, unsigned int pwd_len,
522			   union tcp_addr addr, uint8_t vrf,
523			   uint8_t sndid, uint8_t rcvid,
524			   uint8_t maclen, const char *alg,
525			   bool set_current, bool set_rnext)
526{
527	struct tcp_ao_add tmp = {};
528	uint8_t keyflags = 0;
529	int err;
530
531	if (!alg)
532		alg = DEFAULT_TEST_ALGO;
533
534	if (vrf)
535		keyflags |= TCP_AO_KEYF_IFINDEX;
536	err = test_prepare_key(&tmp, alg, addr, set_current, set_rnext,
537			       DEFAULT_TEST_PREFIX, vrf, sndid, rcvid, maclen,
538			       keyflags, pwd_len, pwd);
539	if (err)
540		return err;
541
542	err = setsockopt(sk, IPPROTO_TCP, TCP_AO_ADD_KEY, &tmp, sizeof(tmp));
543	if (err < 0)
544		return -errno;
545
546	return test_verify_socket_key(sk, &tmp);
547}
548
549static void verify_current_rnext(const char *tst, int sk,
550				 int current_keyid, int rnext_keyid)
551{
552	struct tcp_ao_info_opt ao_info = {};
553
554	if (test_get_ao_info(sk, &ao_info))
555		test_error("getsockopt(TCP_AO_INFO) failed");
556
557	errno = 0;
558	if (current_keyid >= 0) {
559		if (!ao_info.set_current)
560			test_fail("%s: the socket doesn't have current key", tst);
561		else if (ao_info.current_key != current_keyid)
562			test_fail("%s: current key is not the expected one %d != %u",
563				  tst, current_keyid, ao_info.current_key);
564		else
565			test_ok("%s: current key %u as expected",
566				tst, ao_info.current_key);
567	}
568	if (rnext_keyid >= 0) {
569		if (!ao_info.set_rnext)
570			test_fail("%s: the socket doesn't have rnext key", tst);
571		else if (ao_info.rnext != rnext_keyid)
572			test_fail("%s: rnext key is not the expected one %d != %u",
573				  tst, rnext_keyid, ao_info.rnext);
574		else
575			test_ok("%s: rnext key %u as expected", tst, ao_info.rnext);
576	}
577}
578
579
/*
 * Create a socket and add every key of the collection to it.
 *
 * @server: a listening socket on this_ip_addr:@port, using server-side
 *          ids (sndid = server_keyid). Otherwise an unconnected client
 *          socket (@port unused) with client-side ids. Keys with
 *          is_current/is_rnext are selected as current/rnext at add time,
 *          and that choice is recorded in used_on_client_tx and
 *          used_on_server_tx.
 *
 * A key that must not match the peer is added for wrong_addr. A key that
 * must not match the VRF is bound to test_vrf_ifindex.
 * Returns the socket.
 */
static int key_collection_socket(bool server, unsigned int port)
{
	unsigned int i;
	int sk;

	if (server)
		sk = test_listen_socket(this_ip_addr, port, 1);
	else
		sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
	if (sk < 0)
		test_error("socket()");

	for (i = 0; i < collection.nr_keys; i++) {
		struct test_key *key = &collection.keys[i];
		union tcp_addr *addr = &wrong_addr;
		uint8_t sndid, rcvid, vrf;
		bool set_current = false, set_rnext = false;

		if (key->matches_vrf)
			vrf = 0;
		else
			vrf = test_vrf_ifindex;
		if (server) {
			if (key->matches_client)
				addr = &this_ip_dest;
			sndid = key->server_keyid;
			rcvid = key->client_keyid;
		} else {
			if (key->matches_server)
				addr = &this_ip_dest;
			sndid = key->client_keyid;
			rcvid = key->server_keyid;
			/* Remember which keys are expected on the wire (for counters) */
			key->used_on_client_tx = set_current = key->is_current;
			key->used_on_server_tx = set_rnext = key->is_rnext;
		}

		if (test_add_key_cr(sk, key->password, key->len,
				    *addr, vrf, sndid, rcvid, key->maclen,
				    key->alg, set_current, set_rnext))
			test_key_error("setsockopt(TCP_AO_ADD_KEY)", key);
#ifdef DEBUG
		test_print("%s [%u/%u] key: { %s, %u:%u, %u, %u:%u:%u:%u (%u)}",
			   server ? "server" : "client", i, collection.nr_keys,
			   key->alg, rcvid, sndid, key->maclen,
			   key->matches_client, key->matches_server,
			   key->is_current, key->is_rnext, key->len);
#endif
	}
	return sk;
}
630
631static void verify_counters(const char *tst_name, bool is_listen_sk, bool server,
632			    struct tcp_ao_counters *a, struct tcp_ao_counters *b)
633{
634	unsigned int i;
635
636	__test_tcp_ao_counters_cmp(tst_name, a, b, TEST_CNT_GOOD);
637
638	for (i = 0; i < collection.nr_keys; i++) {
639		struct test_key *key = &collection.keys[i];
640		uint8_t sndid, rcvid;
641		bool rx_cnt_expected;
642
643		if (key->skip_counters_checks)
644			continue;
645		if (server) {
646			sndid = key->server_keyid;
647			rcvid = key->client_keyid;
648			rx_cnt_expected = key->used_on_client_tx;
649		} else {
650			sndid = key->client_keyid;
651			rcvid = key->server_keyid;
652			rx_cnt_expected = key->used_on_server_tx;
653		}
654
655		test_tcp_ao_key_counters_cmp(tst_name, a, b,
656					     rx_cnt_expected ? TEST_CNT_KEY_GOOD : 0,
657					     sndid, rcvid);
658	}
659	test_tcp_ao_counters_free(a);
660	test_tcp_ao_counters_free(b);
661	test_ok("%s: passed counters checks", tst_name);
662}
663
664static struct tcp_ao_getsockopt *lookup_key(struct tcp_ao_getsockopt *buf,
665					    size_t len, int sndid, int rcvid)
666{
667	size_t i;
668
669	for (i = 0; i < len; i++) {
670		if (sndid >= 0 && buf[i].sndid != sndid)
671			continue;
672		if (rcvid >= 0 && buf[i].rcvid != rcvid)
673			continue;
674		return &buf[i];
675	}
676	return NULL;
677}
678
679static void verify_keys(const char *tst_name, int sk,
680			bool is_listen_sk, bool server)
681{
682	socklen_t len = sizeof(struct tcp_ao_getsockopt);
683	struct tcp_ao_getsockopt *keys;
684	bool passed_test = true;
685	unsigned int i;
686
687	keys = calloc(collection.nr_keys, len);
688	if (!keys)
689		test_error("calloc()");
690
691	keys->nkeys = collection.nr_keys;
692	keys->get_all = 1;
693
694	if (getsockopt(sk, IPPROTO_TCP, TCP_AO_GET_KEYS, keys, &len)) {
695		free(keys);
696		test_error("getsockopt(TCP_AO_GET_KEYS)");
697	}
698
699	for (i = 0; i < collection.nr_keys; i++) {
700		struct test_key *key = &collection.keys[i];
701		struct tcp_ao_getsockopt *dump_key;
702		bool is_kdf_aes_128_cmac = false;
703		bool is_cmac_aes = false;
704		uint8_t sndid, rcvid;
705		bool matches = false;
706
707		if (server) {
708			if (key->matches_client)
709				matches = true;
710			sndid = key->server_keyid;
711			rcvid = key->client_keyid;
712		} else {
713			if (key->matches_server)
714				matches = true;
715			sndid = key->client_keyid;
716			rcvid = key->server_keyid;
717		}
718		if (!key->matches_vrf)
719			matches = false;
720		/* no keys get removed on the original listener socket */
721		if (is_listen_sk)
722			matches = true;
723
724		dump_key = lookup_key(keys, keys->nkeys, sndid, rcvid);
725		if (matches != !!dump_key) {
726			test_fail("%s: key %u:%u %s%s on the socket",
727				  tst_name, sndid, rcvid,
728				  key->matches_vrf ? "" : "[vrf] ",
729				  matches ? "disappeared" : "yet present");
730			passed_test = false;
731			goto out;
732		}
733		if (!dump_key)
734			continue;
735
736		if (!strcmp("cmac(aes128)", key->alg)) {
737			is_kdf_aes_128_cmac = (key->len != 16);
738			is_cmac_aes = true;
739		}
740
741		if (is_cmac_aes) {
742			if (strcmp(dump_key->alg_name, "cmac(aes)")) {
743				test_fail("%s: key %u:%u cmac(aes) has unexpected alg %s",
744					  tst_name, sndid, rcvid,
745					  dump_key->alg_name);
746				passed_test = false;
747				continue;
748			}
749		} else if (strcmp(dump_key->alg_name, key->alg)) {
750			test_fail("%s: key %u:%u has unexpected alg %s != %s",
751				  tst_name, sndid, rcvid,
752				  dump_key->alg_name, key->alg);
753			passed_test = false;
754			continue;
755		}
756		if (is_kdf_aes_128_cmac) {
757			if (dump_key->keylen != 16) {
758				test_fail("%s: key %u:%u cmac(aes128) has unexpected len %u",
759					  tst_name, sndid, rcvid,
760					  dump_key->keylen);
761				continue;
762			}
763		} else if (dump_key->keylen != key->len) {
764			test_fail("%s: key %u:%u changed password len %u != %u",
765				  tst_name, sndid, rcvid,
766				  dump_key->keylen, key->len);
767			passed_test = false;
768			continue;
769		}
770		if (!is_kdf_aes_128_cmac &&
771		    memcmp(dump_key->key, key->password, key->len)) {
772			test_fail("%s: key %u:%u has different password",
773				  tst_name, sndid, rcvid);
774			passed_test = false;
775			continue;
776		}
777		if (dump_key->maclen != key->maclen) {
778			test_fail("%s: key %u:%u changed maclen %u != %u",
779				  tst_name, sndid, rcvid,
780				  dump_key->maclen, key->maclen);
781			passed_test = false;
782			continue;
783		}
784	}
785
786	if (passed_test)
787		test_ok("%s: The socket keys are consistent with the expectations",
788			tst_name);
789out:
790	free(keys);
791}
792
/*
 * Server half of a connection test:
 *  - install the key collection on a listening socket;
 *  - accept one connection;
 *  - verify the listener's keys and counters;
 *  - serve @quota bytes.
 * The numbered synchronize_threads() points pair with the client's.
 * Returns the accepted socket. @begin receives its initial TCP-AO counters
 * (consumed by end_server()). @current_index and @rnext_index are not used
 * here.
 */
static int start_server(const char *tst_name, unsigned int port, size_t quota,
			struct tcp_ao_counters *begin,
			unsigned int current_index, unsigned int rnext_index)
{
	struct tcp_ao_counters lsk_c1, lsk_c2;
	ssize_t bytes;
	int sk, lsk;

	synchronize_threads(); /* 1: key collection initialized */
	lsk = key_collection_socket(true, port);
	if (test_get_tcp_ao_counters(lsk, &lsk_c1))
		test_error("test_get_tcp_ao_counters()");
	synchronize_threads(); /* 2: MKTs added => connect() */
	if (test_wait_fd(lsk, TEST_TIMEOUT_SEC, 0))
		test_error("test_wait_fd()");

	sk = accept(lsk, NULL, NULL);
	if (sk < 0)
		test_error("accept()");
	if (test_get_tcp_ao_counters(sk, begin))
		test_error("test_get_tcp_ao_counters()");

	synchronize_threads(); /* 3: accepted => send data */
	if (test_get_tcp_ao_counters(lsk, &lsk_c2))
		test_error("test_get_tcp_ao_counters()");
	verify_keys(tst_name, lsk, true, true);
	close(lsk);

	bytes = test_server_run(sk, quota, TEST_TIMEOUT_SEC);
	if (bytes != quota)
		test_fail("%s: server served: %zd", tst_name, bytes);
	else
		test_ok("%s: server alive", tst_name);

	/* Listener counters, sampled before connect() and after accept() */
	verify_counters(tst_name, true, true, &lsk_c1, &lsk_c2);

	return sk;
}
831
/*
 * Finish a server test:
 *  - verify the keys of the accepted socket @sk;
 *  - close it;
 *  - compare its counters against @begin (from start_server()).
 * verify_counters() frees @begin.
 */
static void end_server(const char *tst_name, int sk,
		       struct tcp_ao_counters *begin)
{
	struct tcp_ao_counters end;

	if (test_get_tcp_ao_counters(sk, &end))
		test_error("test_get_tcp_ao_counters()");
	verify_keys(tst_name, sk, false, true);

	synchronize_threads(); /* 4: verified => closed */
	close(sk);

	verify_counters(tst_name, false, true, begin, &end);
	synchronize_threads(); /* 5: counters */
}
847
848static void try_server_run(const char *tst_name, unsigned int port, size_t quota,
849			   unsigned int current_index, unsigned int rnext_index)
850{
851	struct tcp_ao_counters tmp;
852	int sk;
853
854	sk = start_server(tst_name, port, quota, &tmp,
855			  current_index, rnext_index);
856	end_server(tst_name, sk, &tmp);
857}
858
859static void server_rotations(const char *tst_name, unsigned int port,
860			     size_t quota, unsigned int rotations,
861			     unsigned int current_index, unsigned int rnext_index)
862{
863	struct tcp_ao_counters tmp;
864	unsigned int i;
865	int sk;
866
867	sk = start_server(tst_name, port, quota, &tmp,
868			  current_index, rnext_index);
869
870	for (i = current_index + 1; rotations > 0; i++, rotations--) {
871		ssize_t bytes;
872
873		if (i >= collection.nr_keys)
874			i = 0;
875		bytes = test_server_run(sk, quota, TEST_TIMEOUT_SEC);
876		if (bytes != quota) {
877			test_fail("%s: server served: %zd", tst_name, bytes);
878			return;
879		}
880		verify_current_rnext(tst_name, sk,
881				     collection.keys[i].server_keyid, -1);
882		synchronize_threads(); /* verify current/rnext */
883	}
884	end_server(tst_name, sk, &tmp);
885}
886
/*
 * Client half of a connection test:
 *  - install the key collection;
 *  - optionally select current (@current_index) / rnext (@rnext_index)
 *    keys by collection index (negative = don't set);
 *  - connect and verify @msg_nr messages of @msg_sz bytes.
 * If @before is non-NULL, it receives the counters taken before connect().
 * Returns the connected socket. After a failed exchange it returns -1, with
 * the socket closed and @before freed. @tst_name is not used here.
 */
static int run_client(const char *tst_name, unsigned int port,
		      unsigned int nr_keys, int current_index, int rnext_index,
		      struct tcp_ao_counters *before,
		      const size_t msg_sz, const size_t msg_nr)
{
	int sk;

	synchronize_threads(); /* 1: key collection initialized */
	sk = key_collection_socket(false, port);

	if (current_index >= 0 || rnext_index >= 0) {
		int sndid = -1, rcvid = -1;

		if (current_index >= 0)
			sndid = collection.keys[current_index].client_keyid;
		if (rnext_index >= 0)
			rcvid = collection.keys[rnext_index].server_keyid;
		if (test_set_key(sk, sndid, rcvid))
			test_error("failed to set current/rnext keys");
	}
	if (before && test_get_tcp_ao_counters(sk, before))
		test_error("test_get_tcp_ao_counters()");

	synchronize_threads(); /* 2: MKTs added => connect() */
	/* NOTE(review): port++ only changes the local copy and has no effect */
	if (test_connect_socket(sk, this_ip_dest, port++) <= 0)
		test_error("failed to connect()");
	/*
	 * Without an explicit choice, the last key (nr_keys - 1) is expected
	 * to be used; mark the keys expected on the wire for the counters.
	 */
	if (current_index < 0)
		current_index = nr_keys - 1;
	if (rnext_index < 0)
		rnext_index = nr_keys - 1;
	collection.keys[current_index].used_on_client_tx = 1;
	collection.keys[rnext_index].used_on_server_tx = 1;

	synchronize_threads(); /* 3: accepted => send data */
	if (test_client_verify(sk, msg_sz, msg_nr, TEST_TIMEOUT_SEC)) {
		test_fail("verify failed");
		close(sk);
		if (before)
			test_tcp_ao_counters_free(before);
		return -1;
	}

	return sk;
}
931
932static int start_client(const char *tst_name, unsigned int port,
933			unsigned int nr_keys, int current_index, int rnext_index,
934			struct tcp_ao_counters *before,
935			const size_t msg_sz, const size_t msg_nr)
936{
937	if (init_default_key_collection(nr_keys, true))
938		test_error("Failed to init the key collection");
939
940	return run_client(tst_name, port, nr_keys, current_index,
941			  rnext_index, before, msg_sz, msg_nr);
942}
943
944static void end_client(const char *tst_name, int sk, unsigned int nr_keys,
945		       int current_index, int rnext_index,
946		       struct tcp_ao_counters *start)
947{
948	struct tcp_ao_counters end;
949
950	/* Some application may become dependent on this kernel choice */
951	if (current_index < 0)
952		current_index = nr_keys - 1;
953	if (rnext_index < 0)
954		rnext_index = nr_keys - 1;
955	verify_current_rnext(tst_name, sk,
956			     collection.keys[current_index].client_keyid,
957			     collection.keys[rnext_index].server_keyid);
958	if (start && test_get_tcp_ao_counters(sk, &end))
959		test_error("test_get_tcp_ao_counters()");
960	verify_keys(tst_name, sk, false, false);
961	synchronize_threads(); /* 4: verify => closed */
962	close(sk);
963	if (start)
964		verify_counters(tst_name, false, false, start, &end);
965	synchronize_threads(); /* 5: counters */
966}
967
/*
 * Checks on an established client socket @sk:
 *  - re-adding a client key that does not match the server
 *    (!matches_server) must fail with -EINVAL, both for wrong_addr and
 *    when bound to test_vrf_ifindex;
 *  - selecting as rnext the server id of the first key with
 *    !matches_client must still let test_client_verify() succeed.
 * The index of that rnext key is stored in *@rnext_index.
 * NOTE(review): the first loop reads collection.keys[0] unconditionally,
 * so this assumes collection.nr_keys >= 1.
 */
static void try_unmatched_keys(int sk, int *rnext_index)
{
	struct test_key *key;
	unsigned int i = 0;
	int err;

	/* First client key that doesn't match the server */
	do {
		key = &collection.keys[i];
		if (!key->matches_server)
			break;
	} while (++i < collection.nr_keys);
	if (key->matches_server)
		test_error("all keys on client match the server");

	err = test_add_key_cr(sk, key->password, key->len, wrong_addr,
			      0, key->client_keyid, key->server_keyid,
			      key->maclen, key->alg, 0, 0);
	if (!err) {
		test_fail("Added a key with non-matching ip-address for established sk");
		return;
	}
	if (err == -EINVAL)
		test_ok("Can't add a key with non-matching ip-address for established sk");
	else
		test_error("Failed to add a key");

	err = test_add_key_cr(sk, key->password, key->len, this_ip_dest,
			      test_vrf_ifindex,
			      key->client_keyid, key->server_keyid,
			      key->maclen, key->alg, 0, 0);
	if (!err) {
		test_fail("Added a key with non-matching VRF for established sk");
		return;
	}
	if (err == -EINVAL)
		test_ok("Can't add a key with non-matching VRF for established sk");
	else
		test_error("Failed to add a key");

	/* First key the server installed for a peer other than this client */
	for (i = 0; i < collection.nr_keys; i++) {
		key = &collection.keys[i];
		if (!key->matches_client)
			break;
	}
	if (key->matches_client)
		test_error("all keys on server match the client");
	/* NOTE(review): this selects the rnext key, despite the error text */
	if (test_set_key(sk, -1, key->server_keyid))
		test_error("Can't change the current key");
	if (test_client_verify(sk, msg_len, nr_packets, TEST_TIMEOUT_SEC))
		test_fail("verify failed");
	*rnext_index = i;
}
1020
1021static int client_non_matching(const char *tst_name, unsigned int port,
1022			       unsigned int nr_keys,
1023			       int current_index, int rnext_index,
1024			       const size_t msg_sz, const size_t msg_nr)
1025{
1026	unsigned int i;
1027
1028	if (init_default_key_collection(nr_keys, true))
1029		test_error("Failed to init the key collection");
1030
1031	for (i = 0; i < nr_keys; i++) {
1032		/* key (0, 0) matches */
1033		collection.keys[i].matches_client = !!((i + 3) % 4);
1034		collection.keys[i].matches_server = !!((i + 2) % 4);
1035		if (kernel_config_has(KCONFIG_NET_VRF))
1036			collection.keys[i].matches_vrf = !!((i + 1) % 4);
1037	}
1038
1039	return run_client(tst_name, port, nr_keys, current_index,
1040			  rnext_index, NULL, msg_sz, msg_nr);
1041}
1042
1043static void check_current_back(const char *tst_name, unsigned int port,
1044			       unsigned int nr_keys,
1045			       unsigned int current_index, unsigned int rnext_index,
1046			       unsigned int rotate_to_index)
1047{
1048	struct tcp_ao_counters tmp;
1049	int sk;
1050
1051	sk = start_client(tst_name, port, nr_keys, current_index, rnext_index,
1052			  &tmp, msg_len, nr_packets);
1053	if (sk < 0)
1054		return;
1055	if (test_set_key(sk, collection.keys[rotate_to_index].client_keyid, -1))
1056		test_error("Can't change the current key");
1057	if (test_client_verify(sk, msg_len, nr_packets, TEST_TIMEOUT_SEC))
1058		test_fail("verify failed");
1059	/* There is a race here: between setting the current_key with
1060	 * setsockopt(TCP_AO_INFO) and starting to send some data - there
1061	 * might have been a segment received with the desired
1062	 * RNext_key set. In turn that would mean that the first outgoing
1063	 * segment will have the desired current_key (flipped back).
1064	 * Which is what the user/test wants. As it's racy, skip checking
1065	 * the counters, yet check what are the resulting current/rnext
1066	 * keys on both sides.
1067	 */
1068	collection.keys[rotate_to_index].skip_counters_checks = 1;
1069
1070	end_client(tst_name, sk, nr_keys, current_index, rnext_index, &tmp);
1071}
1072
1073static void roll_over_keys(const char *tst_name, unsigned int port,
1074			   unsigned int nr_keys, unsigned int rotations,
1075			   unsigned int current_index, unsigned int rnext_index)
1076{
1077	struct tcp_ao_counters tmp;
1078	unsigned int i;
1079	int sk;
1080
1081	sk = start_client(tst_name, port, nr_keys, current_index, rnext_index,
1082			  &tmp, msg_len, nr_packets);
1083	if (sk < 0)
1084		return;
1085	for (i = rnext_index + 1; rotations > 0; i++, rotations--) {
1086		if (i >= collection.nr_keys)
1087			i = 0;
1088		if (test_set_key(sk, -1, collection.keys[i].server_keyid))
1089			test_error("Can't change the Rnext key");
1090		if (test_client_verify(sk, msg_len, nr_packets, TEST_TIMEOUT_SEC)) {
1091			test_fail("verify failed");
1092			close(sk);
1093			test_tcp_ao_counters_free(&tmp);
1094			return;
1095		}
1096		verify_current_rnext(tst_name, sk, -1,
1097				     collection.keys[i].server_keyid);
1098		collection.keys[i].used_on_server_tx = 1;
1099		synchronize_threads(); /* verify current/rnext */
1100	}
1101	end_client(tst_name, sk, nr_keys, current_index, rnext_index, &tmp);
1102}
1103
1104static void try_client_run(const char *tst_name, unsigned int port,
1105			   unsigned int nr_keys, int current_index, int rnext_index)
1106{
1107	struct tcp_ao_counters tmp;
1108	int sk;
1109
1110	sk = start_client(tst_name, port, nr_keys, current_index, rnext_index,
1111			  &tmp, msg_len, nr_packets);
1112	if (sk < 0)
1113		return;
1114	end_client(tst_name, sk, nr_keys, current_index, rnext_index, &tmp);
1115}
1116
1117static void try_client_match(const char *tst_name, unsigned int port,
1118			     unsigned int nr_keys,
1119			     int current_index, int rnext_index)
1120{
1121	int sk;
1122
1123	sk = client_non_matching(tst_name, port, nr_keys, current_index,
1124				 rnext_index, msg_len, nr_packets);
1125	if (sk < 0)
1126		return;
1127	try_unmatched_keys(sk, &rnext_index);
1128	end_client(tst_name, sk, nr_keys, current_index, rnext_index, NULL);
1129}
1130
1131static void *server_fn(void *arg)
1132{
1133	unsigned int port = test_server_port;
1134
1135	setup_vrfs();
1136	try_server_run("server: Check current/rnext keys unset before connect()",
1137		       port++, quota, 19, 19);
1138	try_server_run("server: Check current/rnext keys set before connect()",
1139		       port++, quota, 10, 10);
1140	try_server_run("server: Check current != rnext keys set before connect()",
1141		       port++, quota, 5, 10);
1142	try_server_run("server: Check current flapping back on peer's RnextKey request",
1143		       port++, quota * 2, 5, 10);
1144	server_rotations("server: Rotate over all different keys", port++,
1145			 quota, 20, 0, 0);
1146	try_server_run("server: Check accept() => established key matching",
1147		       port++, quota * 2, 0, 0);
1148
1149	synchronize_threads(); /* don't race to exit: client exits */
1150	return NULL;
1151}
1152
1153static void check_established_socket(void)
1154{
1155	unsigned int port = test_server_port;
1156
1157	setup_vrfs();
1158	try_client_run("client: Check current/rnext keys unset before connect()",
1159		       port++, 20, -1, -1);
1160	try_client_run("client: Check current/rnext keys set before connect()",
1161		       port++, 20, 10, 10);
1162	try_client_run("client: Check current != rnext keys set before connect()",
1163		       port++, 20, 10, 5);
1164	check_current_back("client: Check current flapping back on peer's RnextKey request",
1165			   port++, 20, 10, 5, 2);
1166	roll_over_keys("client: Rotate over all different keys", port++,
1167		       20, 20, 0, 0);
1168	try_client_match("client: Check connect() => established key matching",
1169			 port++, 20, 0, 0);
1170}
1171
1172static void *client_fn(void *arg)
1173{
1174	if (inet_pton(TEST_FAMILY, TEST_WRONG_IP, &wrong_addr) != 1)
1175		test_error("Can't convert ip address %s", TEST_WRONG_IP);
1176	check_closed_socket();
1177	check_listen_socket();
1178	check_established_socket();
1179	return NULL;
1180}
1181
1182int main(int argc, char *argv[])
1183{
1184	test_init(120, server_fn, client_fn);
1185	return 0;
1186}
1187