1// SPDX-License-Identifier: GPL-2.0
2/* Author: Dmitry Safonov <dima@arista.com> */
3#include <inttypes.h>
4#include "aolib.h"
5
/*
 * Loopback address of the test: filled in by __setup_lo_intf() via
 * inet_pton() and then used by tcp_self_connect() both as the bind/connect
 * address and as the peer address of the TCP-AO keys.
 */
static union tcp_addr local_addr;
7
8static void __setup_lo_intf(const char *lo_intf,
9			    const char *addr_str, uint8_t prefix)
10{
11	if (inet_pton(TEST_FAMILY, addr_str, &local_addr) != 1)
12		test_error("Can't convert local ip address");
13
14	if (ip_addr_add(lo_intf, TEST_FAMILY, local_addr, prefix))
15		test_error("Failed to add %s ip address", lo_intf);
16
17	if (link_set_up(lo_intf))
18		test_error("Failed to bring %s up", lo_intf);
19}
20
/*
 * setup_lo_intf() - assign the family-appropriate loopback address to
 * @lo_intf: ::1/128 when built with IPV6_TEST, 127.0.0.1/8 otherwise.
 */
static void setup_lo_intf(const char *lo_intf)
{
#ifdef IPV6_TEST
	const char *loopback = "::1";
	const uint8_t plen = 128;
#else
	const char *loopback = "127.0.0.1";
	const uint8_t plen = 8;
#endif

	__setup_lo_intf(lo_intf, loopback, plen);
}
29
30static void tcp_self_connect(const char *tst, unsigned int port,
31			     bool different_keyids, bool check_restore)
32{
33	uint64_t before_challenge_ack, after_challenge_ack;
34	uint64_t before_syn_challenge, after_syn_challenge;
35	struct tcp_ao_counters before_ao, after_ao;
36	uint64_t before_aogood, after_aogood;
37	struct netstat *ns_before, *ns_after;
38	const size_t nr_packets = 20;
39	struct tcp_ao_repair ao_img;
40	struct tcp_sock_state img;
41	sockaddr_af addr;
42	int sk;
43
44	tcp_addr_to_sockaddr_in(&addr, &local_addr, htons(port));
45
46	sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
47	if (sk < 0)
48		test_error("socket()");
49
50	if (different_keyids) {
51		if (test_add_key(sk, DEFAULT_TEST_PASSWORD, local_addr, -1, 5, 7))
52			test_error("setsockopt(TCP_AO_ADD_KEY)");
53		if (test_add_key(sk, DEFAULT_TEST_PASSWORD, local_addr, -1, 7, 5))
54			test_error("setsockopt(TCP_AO_ADD_KEY)");
55	} else {
56		if (test_add_key(sk, DEFAULT_TEST_PASSWORD, local_addr, -1, 100, 100))
57			test_error("setsockopt(TCP_AO_ADD_KEY)");
58	}
59
60	if (bind(sk, (struct sockaddr *)&addr, sizeof(addr)) < 0)
61		test_error("bind()");
62
63	ns_before = netstat_read();
64	before_aogood = netstat_get(ns_before, "TCPAOGood", NULL);
65	before_challenge_ack = netstat_get(ns_before, "TCPChallengeACK", NULL);
66	before_syn_challenge = netstat_get(ns_before, "TCPSYNChallenge", NULL);
67	if (test_get_tcp_ao_counters(sk, &before_ao))
68		test_error("test_get_tcp_ao_counters()");
69
70	if (__test_connect_socket(sk, "lo", (struct sockaddr *)&addr,
71				  sizeof(addr), TEST_TIMEOUT_SEC) < 0) {
72		ns_after = netstat_read();
73		netstat_print_diff(ns_before, ns_after);
74		test_error("failed to connect()");
75	}
76
77	if (test_client_verify(sk, 100, nr_packets, TEST_TIMEOUT_SEC)) {
78		test_fail("%s: tcp connection verify failed", tst);
79		close(sk);
80		return;
81	}
82
83	ns_after = netstat_read();
84	after_aogood = netstat_get(ns_after, "TCPAOGood", NULL);
85	after_challenge_ack = netstat_get(ns_after, "TCPChallengeACK", NULL);
86	after_syn_challenge = netstat_get(ns_after, "TCPSYNChallenge", NULL);
87	if (test_get_tcp_ao_counters(sk, &after_ao))
88		test_error("test_get_tcp_ao_counters()");
89	if (!check_restore) {
90		/* to debug: netstat_print_diff(ns_before, ns_after); */
91		netstat_free(ns_before);
92	}
93	netstat_free(ns_after);
94
95	if (after_aogood <= before_aogood) {
96		test_fail("%s: TCPAOGood counter mismatch: %zu <= %zu",
97			  tst, after_aogood, before_aogood);
98		close(sk);
99		return;
100	}
101	if (after_challenge_ack <= before_challenge_ack ||
102	    after_syn_challenge <= before_syn_challenge) {
103		/*
104		 * It's also meant to test simultaneous open, so check
105		 * these counters as well.
106		 */
107		test_fail("%s: Didn't challenge SYN or ACK: %zu <= %zu OR %zu <= %zu",
108			  tst, after_challenge_ack, before_challenge_ack,
109			  after_syn_challenge, before_syn_challenge);
110		close(sk);
111		return;
112	}
113
114	if (test_tcp_ao_counters_cmp(tst, &before_ao, &after_ao, TEST_CNT_GOOD)) {
115		close(sk);
116		return;
117	}
118
119	if (!check_restore) {
120		test_ok("%s: connect TCPAOGood %" PRIu64 " => %" PRIu64,
121				tst, before_aogood, after_aogood);
122		close(sk);
123		return;
124	}
125
126	test_enable_repair(sk);
127	test_sock_checkpoint(sk, &img, &addr);
128#ifdef IPV6_TEST
129	addr.sin6_port = htons(port + 1);
130#else
131	addr.sin_port = htons(port + 1);
132#endif
133	test_ao_checkpoint(sk, &ao_img);
134	test_kill_sk(sk);
135
136	sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
137	if (sk < 0)
138		test_error("socket()");
139
140	test_enable_repair(sk);
141	__test_sock_restore(sk, "lo", &img, &addr, &addr, sizeof(addr));
142	if (different_keyids) {
143		if (test_add_repaired_key(sk, DEFAULT_TEST_PASSWORD, 0,
144					  local_addr, -1, 7, 5))
145			test_error("setsockopt(TCP_AO_ADD_KEY)");
146		if (test_add_repaired_key(sk, DEFAULT_TEST_PASSWORD, 0,
147					  local_addr, -1, 5, 7))
148			test_error("setsockopt(TCP_AO_ADD_KEY)");
149	} else {
150		if (test_add_repaired_key(sk, DEFAULT_TEST_PASSWORD, 0,
151					  local_addr, -1, 100, 100))
152			test_error("setsockopt(TCP_AO_ADD_KEY)");
153	}
154	test_ao_restore(sk, &ao_img);
155	test_disable_repair(sk);
156	test_sock_state_free(&img);
157	if (test_client_verify(sk, 100, nr_packets, TEST_TIMEOUT_SEC)) {
158		test_fail("%s: tcp connection verify failed", tst);
159		close(sk);
160		return;
161	}
162	ns_after = netstat_read();
163	after_aogood = netstat_get(ns_after, "TCPAOGood", NULL);
164	/* to debug: netstat_print_diff(ns_before, ns_after); */
165	netstat_free(ns_before);
166	netstat_free(ns_after);
167	close(sk);
168	if (after_aogood <= before_aogood) {
169		test_fail("%s: TCPAOGood counter mismatch: %zu <= %zu",
170			  tst, after_aogood, before_aogood);
171		return;
172	}
173	test_ok("%s: connect TCPAOGood %" PRIu64 " => %" PRIu64,
174			tst, before_aogood, after_aogood);
175}
176
177static void *client_fn(void *arg)
178{
179	unsigned int port = test_server_port;
180
181	setup_lo_intf("lo");
182
183	tcp_self_connect("self-connect(same keyids)", port++, false, false);
184	tcp_self_connect("self-connect(different keyids)", port++, true, false);
185	tcp_self_connect("self-connect(restore)", port, false, true);
186	port += 2;
187	tcp_self_connect("self-connect(restore, different keyids)", port, true, true);
188	port += 2;
189
190	return NULL;
191}
192
193int main(int argc, char *argv[])
194{
195	test_init(4, client_fn, NULL);
196	return 0;
197}
198