1 // SPDX-License-Identifier: GPL-2.0
2 /* Author: Dmitry Safonov <dima@arista.com> */
3 #include <inttypes.h>
4 #include "aolib.h"
5 
/*
 * fault(TYPE) - true when the injected fault for the current sub-test is
 * FAULT_TYPE.  The expansion refers to a variable literally named `inj`, so
 * the macro is only usable where such a variable is in scope (the fault_t
 * parameter of try_accept() and try_connect() in this file).
 */
#define fault(type)	(inj == FAULT_ ## type)
7 
test_add_key_maclen(int sk,const char * key,uint8_t maclen,union tcp_addr in_addr,uint8_t prefix,uint8_t sndid,uint8_t rcvid)8 static inline int test_add_key_maclen(int sk, const char *key, uint8_t maclen,
9 				      union tcp_addr in_addr, uint8_t prefix,
10 				      uint8_t sndid, uint8_t rcvid)
11 {
12 	struct tcp_ao_add tmp = {};
13 	int err;
14 
15 	if (prefix > DEFAULT_TEST_PREFIX)
16 		prefix = DEFAULT_TEST_PREFIX;
17 
18 	err = test_prepare_key(&tmp, DEFAULT_TEST_ALGO, in_addr, false, false,
19 			       prefix, 0, sndid, rcvid, maclen,
20 			       0, strlen(key), key);
21 	if (err)
22 		return err;
23 
24 	err = setsockopt(sk, IPPROTO_TCP, TCP_AO_ADD_KEY, &tmp, sizeof(tmp));
25 	if (err < 0)
26 		return -errno;
27 
28 	return test_verify_socket_key(sk, &tmp);
29 }
30 
try_accept(const char * tst_name,unsigned int port,const char * pwd,union tcp_addr addr,uint8_t prefix,uint8_t sndid,uint8_t rcvid,uint8_t maclen,const char * cnt_name,test_cnt cnt_expected,fault_t inj)31 static void try_accept(const char *tst_name, unsigned int port, const char *pwd,
32 		       union tcp_addr addr, uint8_t prefix,
33 		       uint8_t sndid, uint8_t rcvid, uint8_t maclen,
34 		       const char *cnt_name, test_cnt cnt_expected,
35 		       fault_t inj)
36 {
37 	struct tcp_ao_counters ao_cnt1, ao_cnt2;
38 	uint64_t before_cnt = 0, after_cnt = 0; /* silence GCC */
39 	int lsk, err, sk = 0;
40 	time_t timeout;
41 
42 	lsk = test_listen_socket(this_ip_addr, port, 1);
43 
44 	if (pwd && test_add_key_maclen(lsk, pwd, maclen, addr, prefix, sndid, rcvid))
45 		test_error("setsockopt(TCP_AO_ADD_KEY)");
46 
47 	if (cnt_name)
48 		before_cnt = netstat_get_one(cnt_name, NULL);
49 	if (pwd && test_get_tcp_ao_counters(lsk, &ao_cnt1))
50 		test_error("test_get_tcp_ao_counters()");
51 
52 	synchronize_threads(); /* preparations done */
53 
54 	timeout = fault(TIMEOUT) ? TEST_RETRANSMIT_SEC : TEST_TIMEOUT_SEC;
55 	err = test_wait_fd(lsk, timeout, 0);
56 	if (err == -ETIMEDOUT) {
57 		if (!fault(TIMEOUT))
58 			test_fail("timed out for accept()");
59 	} else if (err < 0) {
60 		test_error("test_wait_fd()");
61 	} else {
62 		if (fault(TIMEOUT))
63 			test_fail("ready to accept");
64 
65 		sk = accept(lsk, NULL, NULL);
66 		if (sk < 0) {
67 			test_error("accept()");
68 		} else {
69 			if (fault(TIMEOUT))
70 				test_fail("%s: accepted", tst_name);
71 		}
72 	}
73 
74 	synchronize_threads(); /* before counter checks */
75 	if (pwd && test_get_tcp_ao_counters(lsk, &ao_cnt2))
76 		test_error("test_get_tcp_ao_counters()");
77 
78 	close(lsk);
79 
80 	if (pwd)
81 		test_tcp_ao_counters_cmp(tst_name, &ao_cnt1, &ao_cnt2, cnt_expected);
82 
83 	if (!cnt_name)
84 		goto out;
85 
86 	after_cnt = netstat_get_one(cnt_name, NULL);
87 
88 	if (after_cnt <= before_cnt) {
89 		test_fail("%s: %s counter did not increase: %" PRIu64 " <= %" PRIu64,
90 				tst_name, cnt_name, after_cnt, before_cnt);
91 	} else {
92 		test_ok("%s: counter %s increased %" PRIu64  " => %" PRIu64,
93 			tst_name, cnt_name, before_cnt, after_cnt);
94 	}
95 
96 out:
97 	synchronize_threads(); /* close() */
98 	if (sk > 0)
99 		close(sk);
100 }
101 
server_fn(void * arg)102 static void *server_fn(void *arg)
103 {
104 	union tcp_addr wrong_addr, network_addr;
105 	unsigned int port = test_server_port;
106 
107 	if (inet_pton(TEST_FAMILY, TEST_WRONG_IP, &wrong_addr) != 1)
108 		test_error("Can't convert ip address %s", TEST_WRONG_IP);
109 
110 	try_accept("Non-AO server + AO client", port++, NULL,
111 		   this_ip_dest, -1, 100, 100, 0,
112 		   "TCPAOKeyNotFound", 0, FAULT_TIMEOUT);
113 
114 	try_accept("AO server + Non-AO client", port++, DEFAULT_TEST_PASSWORD,
115 		   this_ip_dest, -1, 100, 100, 0,
116 		   "TCPAORequired", TEST_CNT_AO_REQUIRED, FAULT_TIMEOUT);
117 
118 	try_accept("Wrong password", port++, "something that is not DEFAULT_TEST_PASSWORD",
119 		   this_ip_dest, -1, 100, 100, 0,
120 		   "TCPAOBad", TEST_CNT_BAD, FAULT_TIMEOUT);
121 
122 	try_accept("Wrong rcv id", port++, DEFAULT_TEST_PASSWORD,
123 		   this_ip_dest, -1, 100, 101, 0,
124 		   "TCPAOKeyNotFound", TEST_CNT_AO_KEY_NOT_FOUND, FAULT_TIMEOUT);
125 
126 	try_accept("Wrong snd id", port++, DEFAULT_TEST_PASSWORD,
127 		   this_ip_dest, -1, 101, 100, 0,
128 		   "TCPAOGood", TEST_CNT_GOOD, FAULT_TIMEOUT);
129 
130 	try_accept("Different maclen", port++, DEFAULT_TEST_PASSWORD,
131 		   this_ip_dest, -1, 100, 100, 8,
132 		   "TCPAOBad", TEST_CNT_BAD, FAULT_TIMEOUT);
133 
134 	try_accept("Server: Wrong addr", port++, DEFAULT_TEST_PASSWORD,
135 		   wrong_addr, -1, 100, 100, 0,
136 		   "TCPAOKeyNotFound", TEST_CNT_AO_KEY_NOT_FOUND, FAULT_TIMEOUT);
137 
138 	try_accept("Client: Wrong addr", port++, NULL,
139 		   this_ip_dest, -1, 100, 100, 0, NULL, 0, FAULT_TIMEOUT);
140 
141 	try_accept("rcv id != snd id", port++, DEFAULT_TEST_PASSWORD,
142 		   this_ip_dest, -1, 200, 100, 0,
143 		   "TCPAOGood", TEST_CNT_GOOD, 0);
144 
145 	if (inet_pton(TEST_FAMILY, TEST_NETWORK, &network_addr) != 1)
146 		test_error("Can't convert ip address %s", TEST_NETWORK);
147 
148 	try_accept("Server: prefix match", port++, DEFAULT_TEST_PASSWORD,
149 		   network_addr, 16, 100, 100, 0,
150 		   "TCPAOGood", TEST_CNT_GOOD, 0);
151 
152 	try_accept("Client: prefix match", port++, DEFAULT_TEST_PASSWORD,
153 		   this_ip_dest, -1, 100, 100, 0,
154 		   "TCPAOGood", TEST_CNT_GOOD, 0);
155 
156 	/* client exits */
157 	synchronize_threads();
158 	return NULL;
159 }
160 
try_connect(const char * tst_name,unsigned int port,const char * pwd,union tcp_addr addr,uint8_t prefix,uint8_t sndid,uint8_t rcvid,test_cnt cnt_expected,fault_t inj)161 static void try_connect(const char *tst_name, unsigned int port,
162 			const char *pwd, union tcp_addr addr, uint8_t prefix,
163 			uint8_t sndid, uint8_t rcvid,
164 			test_cnt cnt_expected, fault_t inj)
165 {
166 	struct tcp_ao_counters ao_cnt1, ao_cnt2;
167 	time_t timeout;
168 	int sk, ret;
169 
170 	sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
171 	if (sk < 0)
172 		test_error("socket()");
173 
174 	if (pwd && test_add_key(sk, pwd, addr, prefix, sndid, rcvid))
175 		test_error("setsockopt(TCP_AO_ADD_KEY)");
176 
177 	if (pwd && test_get_tcp_ao_counters(sk, &ao_cnt1))
178 		test_error("test_get_tcp_ao_counters()");
179 
180 	synchronize_threads(); /* preparations done */
181 
182 	timeout = fault(TIMEOUT) ? TEST_RETRANSMIT_SEC : TEST_TIMEOUT_SEC;
183 	ret = _test_connect_socket(sk, this_ip_dest, port, timeout);
184 
185 	synchronize_threads(); /* before counter checks */
186 	if (ret < 0) {
187 		if (fault(KEYREJECT) && ret == -EKEYREJECTED) {
188 			test_ok("%s: connect() was prevented", tst_name);
189 		} else if (ret == -ETIMEDOUT && fault(TIMEOUT)) {
190 			test_ok("%s", tst_name);
191 		} else if (ret == -ECONNREFUSED &&
192 				(fault(TIMEOUT) || fault(KEYREJECT))) {
193 			test_ok("%s: refused to connect", tst_name);
194 		} else {
195 			test_error("%s: connect() returned %d", tst_name, ret);
196 		}
197 		goto out;
198 	}
199 
200 	if (fault(TIMEOUT) || fault(KEYREJECT))
201 		test_fail("%s: connected", tst_name);
202 	else
203 		test_ok("%s: connected", tst_name);
204 	if (pwd && ret > 0) {
205 		if (test_get_tcp_ao_counters(sk, &ao_cnt2))
206 			test_error("test_get_tcp_ao_counters()");
207 		test_tcp_ao_counters_cmp(tst_name, &ao_cnt1, &ao_cnt2, cnt_expected);
208 	}
209 out:
210 	synchronize_threads(); /* close() */
211 
212 	if (ret > 0)
213 		close(sk);
214 }
215 
client_fn(void * arg)216 static void *client_fn(void *arg)
217 {
218 	union tcp_addr wrong_addr, network_addr, addr_any = {};
219 	unsigned int port = test_server_port;
220 
221 	if (inet_pton(TEST_FAMILY, TEST_WRONG_IP, &wrong_addr) != 1)
222 		test_error("Can't convert ip address %s", TEST_WRONG_IP);
223 
224 	trace_ao_event_expect(TCP_AO_KEY_NOT_FOUND, this_ip_addr, this_ip_dest,
225 			      -1, port, 0, 0, 1, 0, 0, 0, 100, 100, -1);
226 	try_connect("Non-AO server + AO client", port++, DEFAULT_TEST_PASSWORD,
227 			this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT);
228 
229 	trace_hash_event_expect(TCP_HASH_AO_REQUIRED, this_ip_addr, this_ip_dest,
230 				-1, port, 0, 0, 1, 0, 0, 0);
231 	try_connect("AO server + Non-AO client", port++, NULL,
232 			this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT);
233 
234 	trace_ao_event_expect(TCP_AO_MISMATCH, this_ip_addr, this_ip_dest,
235 			      -1, port, 0, 0, 1, 0, 0, 0, 100, 100, -1);
236 	try_connect("Wrong password", port++, DEFAULT_TEST_PASSWORD,
237 			this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT);
238 
239 	trace_ao_event_expect(TCP_AO_KEY_NOT_FOUND, this_ip_addr, this_ip_dest,
240 			      -1, port, 0, 0, 1, 0, 0, 0, 100, 100, -1);
241 	try_connect("Wrong rcv id", port++, DEFAULT_TEST_PASSWORD,
242 			this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT);
243 
244 	trace_ao_event_sk_expect(TCP_AO_SYNACK_NO_KEY, this_ip_dest, addr_any,
245 				 port, 0, 100, 100);
246 	try_connect("Wrong snd id", port++, DEFAULT_TEST_PASSWORD,
247 			this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT);
248 
249 	trace_ao_event_expect(TCP_AO_WRONG_MACLEN, this_ip_addr, this_ip_dest,
250 			      -1, port, 0, 0, 1, 0, 0, 0, 100, 100, -1);
251 	try_connect("Different maclen", port++, DEFAULT_TEST_PASSWORD,
252 			this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT);
253 
254 	trace_ao_event_expect(TCP_AO_KEY_NOT_FOUND, this_ip_addr, this_ip_dest,
255 			      -1, port, 0, 0, 1, 0, 0, 0, 100, 100, -1);
256 	try_connect("Server: Wrong addr", port++, DEFAULT_TEST_PASSWORD,
257 			this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT);
258 
259 	try_connect("Client: Wrong addr", port++, DEFAULT_TEST_PASSWORD,
260 			wrong_addr, -1, 100, 100, 0, FAULT_KEYREJECT);
261 
262 	try_connect("rcv id != snd id", port++, DEFAULT_TEST_PASSWORD,
263 			this_ip_dest, -1, 100, 200, TEST_CNT_GOOD, 0);
264 
265 	if (inet_pton(TEST_FAMILY, TEST_NETWORK, &network_addr) != 1)
266 		test_error("Can't convert ip address %s", TEST_NETWORK);
267 
268 	try_connect("Server: prefix match", port++, DEFAULT_TEST_PASSWORD,
269 			this_ip_dest, -1, 100, 100, TEST_CNT_GOOD, 0);
270 
271 	try_connect("Client: prefix match", port++, DEFAULT_TEST_PASSWORD,
272 			network_addr, 16, 100, 100, TEST_CNT_GOOD, 0);
273 
274 	return NULL;
275 }
276 
main(int argc,char * argv[])277 int main(int argc, char *argv[])
278 {
279 	test_init(22, server_fn, client_fn);
280 	return 0;
281 }
282