1 // SPDX-License-Identifier: GPL-2.0
2 /* Author: Dmitry Safonov <dima@arista.com> */
3 #include <inttypes.h>
4 #include "aolib.h"
5
/*
 * fault(type) - true when the fault selector of the enclosing function equals
 * FAULT_<type>.
 *
 * The expansion compares against an identifier literally named `inj`, so the
 * macro is only usable inside functions that have a variable or parameter
 * with that name in scope.
 */
#define fault(type) (inj == FAULT_ ## type)
7
test_add_key_maclen(int sk,const char * key,uint8_t maclen,union tcp_addr in_addr,uint8_t prefix,uint8_t sndid,uint8_t rcvid)8 static inline int test_add_key_maclen(int sk, const char *key, uint8_t maclen,
9 union tcp_addr in_addr, uint8_t prefix,
10 uint8_t sndid, uint8_t rcvid)
11 {
12 struct tcp_ao_add tmp = {};
13 int err;
14
15 if (prefix > DEFAULT_TEST_PREFIX)
16 prefix = DEFAULT_TEST_PREFIX;
17
18 err = test_prepare_key(&tmp, DEFAULT_TEST_ALGO, in_addr, false, false,
19 prefix, 0, sndid, rcvid, maclen,
20 0, strlen(key), key);
21 if (err)
22 return err;
23
24 err = setsockopt(sk, IPPROTO_TCP, TCP_AO_ADD_KEY, &tmp, sizeof(tmp));
25 if (err < 0)
26 return -errno;
27
28 return test_verify_socket_key(sk, &tmp);
29 }
30
/*
 * Run one accept-side scenario on a fresh listening socket bound to @port.
 *
 * @tst_name:     label used in the reported results
 * @pwd:          key to install on the listener through
 *                test_add_key_maclen(), or NULL to listen without any key
 * @addr, @prefix, @sndid, @rcvid, @maclen:
 *                forwarded to test_add_key_maclen() (ignored when @pwd is NULL)
 * @cnt_name:     netstat counter that must have grown by the end of the
 *                scenario, or NULL to skip that check
 * @cnt_expected: handed to test_tcp_ao_counters_cmp(); only used with @pwd
 * @inj:          with FAULT_TIMEOUT the listener is expected NOT to become
 *                readable within TEST_RETRANSMIT_SEC; otherwise a connection
 *                must arrive within TEST_TIMEOUT_SEC
 *
 * The three synchronize_threads() calls carry the same labels as the ones in
 * try_connect() and appear to be the rendezvous points with the client
 * thread, so their number and order must not change.
 */
static void try_accept(const char *tst_name, unsigned int port, const char *pwd,
		       union tcp_addr addr, uint8_t prefix,
		       uint8_t sndid, uint8_t rcvid, uint8_t maclen,
		       const char *cnt_name, test_cnt cnt_expected,
		       fault_t inj)
{
	struct tcp_ao_counters ao_before, ao_after;
	uint64_t netstat_before = 0, netstat_after = 0; /* silence GCC */
	const bool expect_timeout = fault(TIMEOUT);
	int listen_sk, conn_sk = 0, rc;
	time_t wait_sec;

	listen_sk = test_listen_socket(this_ip_addr, port, 1);

	if (pwd && test_add_key_maclen(listen_sk, pwd, maclen, addr, prefix,
				       sndid, rcvid))
		test_error("setsockopt(TCP_AO_ADD_KEY)");

	/* Snapshot the counters before the first rendezvous. */
	if (cnt_name)
		netstat_before = netstat_get_one(cnt_name, NULL);
	if (pwd && test_get_tcp_ao_counters(listen_sk, &ao_before))
		test_error("test_get_tcp_ao_counters()");

	synchronize_threads(); /* preparations done */

	wait_sec = expect_timeout ? TEST_RETRANSMIT_SEC : TEST_TIMEOUT_SEC;
	rc = test_wait_fd(listen_sk, wait_sec, 0);
	if (rc >= 0) {
		/* The listener became readable: something is there to accept. */
		if (expect_timeout)
			test_fail("ready to accept");

		conn_sk = accept(listen_sk, NULL, NULL);
		if (conn_sk < 0) {
			test_error("accept()");
		} else if (expect_timeout) {
			test_fail("%s: accepted", tst_name);
		}
	} else if (rc != -ETIMEDOUT) {
		test_error("test_wait_fd()");
	} else if (!expect_timeout) {
		test_fail("timed out for accept()");
	}

	synchronize_threads(); /* before counter checks */
	if (pwd && test_get_tcp_ao_counters(listen_sk, &ao_after))
		test_error("test_get_tcp_ao_counters()");

	close(listen_sk);

	if (pwd)
		test_tcp_ao_counters_cmp(tst_name, &ao_before, &ao_after, cnt_expected);

	if (cnt_name) {
		netstat_after = netstat_get_one(cnt_name, NULL);

		if (netstat_after > netstat_before) {
			test_ok("%s: counter %s increased %" PRIu64 " => %" PRIu64,
				tst_name, cnt_name, netstat_before, netstat_after);
		} else {
			test_fail("%s: %s counter did not increase: %" PRIu64 " <= %" PRIu64,
				  tst_name, cnt_name, netstat_after, netstat_before);
		}
	}

	synchronize_threads(); /* close() */
	if (conn_sk > 0)
		close(conn_sk);
}
101
server_fn(void * arg)102 static void *server_fn(void *arg)
103 {
104 union tcp_addr wrong_addr, network_addr;
105 unsigned int port = test_server_port;
106
107 if (inet_pton(TEST_FAMILY, TEST_WRONG_IP, &wrong_addr) != 1)
108 test_error("Can't convert ip address %s", TEST_WRONG_IP);
109
110 try_accept("Non-AO server + AO client", port++, NULL,
111 this_ip_dest, -1, 100, 100, 0,
112 "TCPAOKeyNotFound", 0, FAULT_TIMEOUT);
113
114 try_accept("AO server + Non-AO client", port++, DEFAULT_TEST_PASSWORD,
115 this_ip_dest, -1, 100, 100, 0,
116 "TCPAORequired", TEST_CNT_AO_REQUIRED, FAULT_TIMEOUT);
117
118 try_accept("Wrong password", port++, "something that is not DEFAULT_TEST_PASSWORD",
119 this_ip_dest, -1, 100, 100, 0,
120 "TCPAOBad", TEST_CNT_BAD, FAULT_TIMEOUT);
121
122 try_accept("Wrong rcv id", port++, DEFAULT_TEST_PASSWORD,
123 this_ip_dest, -1, 100, 101, 0,
124 "TCPAOKeyNotFound", TEST_CNT_AO_KEY_NOT_FOUND, FAULT_TIMEOUT);
125
126 try_accept("Wrong snd id", port++, DEFAULT_TEST_PASSWORD,
127 this_ip_dest, -1, 101, 100, 0,
128 "TCPAOGood", TEST_CNT_GOOD, FAULT_TIMEOUT);
129
130 try_accept("Different maclen", port++, DEFAULT_TEST_PASSWORD,
131 this_ip_dest, -1, 100, 100, 8,
132 "TCPAOBad", TEST_CNT_BAD, FAULT_TIMEOUT);
133
134 try_accept("Server: Wrong addr", port++, DEFAULT_TEST_PASSWORD,
135 wrong_addr, -1, 100, 100, 0,
136 "TCPAOKeyNotFound", TEST_CNT_AO_KEY_NOT_FOUND, FAULT_TIMEOUT);
137
138 try_accept("Client: Wrong addr", port++, NULL,
139 this_ip_dest, -1, 100, 100, 0, NULL, 0, FAULT_TIMEOUT);
140
141 try_accept("rcv id != snd id", port++, DEFAULT_TEST_PASSWORD,
142 this_ip_dest, -1, 200, 100, 0,
143 "TCPAOGood", TEST_CNT_GOOD, 0);
144
145 if (inet_pton(TEST_FAMILY, TEST_NETWORK, &network_addr) != 1)
146 test_error("Can't convert ip address %s", TEST_NETWORK);
147
148 try_accept("Server: prefix match", port++, DEFAULT_TEST_PASSWORD,
149 network_addr, 16, 100, 100, 0,
150 "TCPAOGood", TEST_CNT_GOOD, 0);
151
152 try_accept("Client: prefix match", port++, DEFAULT_TEST_PASSWORD,
153 this_ip_dest, -1, 100, 100, 0,
154 "TCPAOGood", TEST_CNT_GOOD, 0);
155
156 /* client exits */
157 synchronize_threads();
158 return NULL;
159 }
160
/*
 * Run one connect-side scenario against this_ip_dest:@port.
 *
 * @tst_name:     label used in the reported results
 * @pwd:          key to install on the socket through test_add_key(), or
 *                NULL to connect without any key
 * @addr, @prefix, @sndid, @rcvid:
 *                forwarded to test_add_key(); note that @addr only selects
 *                the peer the key is bound to, the connection itself always
 *                goes to this_ip_dest
 * @cnt_expected: handed to test_tcp_ao_counters_cmp() after a successful
 *                connect with a key
 * @inj:          expected failure mode. FAULT_TIMEOUT accepts -ETIMEDOUT or
 *                -ECONNREFUSED, FAULT_KEYREJECT accepts -EKEYREJECTED or
 *                -ECONNREFUSED; with either of them a successful connect is
 *                reported as a failure. Any other negative result goes to
 *                test_error().
 *
 * The socket is closed here only when _test_connect_socket() returned a
 * positive value.
 * NOTE(review): presumably the helper disposes of the socket itself on
 * failure - confirm against its implementation.
 *
 * The three synchronize_threads() calls carry the same labels as the ones in
 * try_accept() and appear to be the rendezvous points with the server
 * thread, so their number and order must not change.
 */
static void try_connect(const char *tst_name, unsigned int port,
			const char *pwd, union tcp_addr addr, uint8_t prefix,
			uint8_t sndid, uint8_t rcvid,
			test_cnt cnt_expected, fault_t inj)
{
	struct tcp_ao_counters ao_before, ao_after;
	const bool want_timeout = fault(TIMEOUT);
	const bool want_reject = fault(KEYREJECT);
	time_t wait_sec;
	int sk, ret;

	sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
	if (sk < 0)
		test_error("socket()");

	if (pwd && test_add_key(sk, pwd, addr, prefix, sndid, rcvid))
		test_error("setsockopt(TCP_AO_ADD_KEY)");

	if (pwd && test_get_tcp_ao_counters(sk, &ao_before))
		test_error("test_get_tcp_ao_counters()");

	synchronize_threads(); /* preparations done */

	wait_sec = want_timeout ? TEST_RETRANSMIT_SEC : TEST_TIMEOUT_SEC;
	ret = _test_connect_socket(sk, this_ip_dest, port, wait_sec);

	synchronize_threads(); /* before counter checks */
	if (ret < 0) {
		/* connect() failed: fine only if it failed the expected way. */
		if (ret == -EKEYREJECTED && want_reject) {
			test_ok("%s: connect() was prevented", tst_name);
		} else if (ret == -ETIMEDOUT && want_timeout) {
			test_ok("%s", tst_name);
		} else if (ret == -ECONNREFUSED && (want_timeout || want_reject)) {
			test_ok("%s: refused to connect", tst_name);
		} else {
			test_error("%s: connect() returned %d", tst_name, ret);
		}
	} else {
		if (want_timeout || want_reject) {
			test_fail("%s: connected", tst_name);
		} else {
			test_ok("%s: connected", tst_name);
		}

		if (pwd && ret > 0) {
			if (test_get_tcp_ao_counters(sk, &ao_after))
				test_error("test_get_tcp_ao_counters()");
			test_tcp_ao_counters_cmp(tst_name, &ao_before, &ao_after,
						 cnt_expected);
		}
	}

	synchronize_threads(); /* close() */

	if (ret > 0)
		close(sk);
}
215
client_fn(void * arg)216 static void *client_fn(void *arg)
217 {
218 union tcp_addr wrong_addr, network_addr, addr_any = {};
219 unsigned int port = test_server_port;
220
221 if (inet_pton(TEST_FAMILY, TEST_WRONG_IP, &wrong_addr) != 1)
222 test_error("Can't convert ip address %s", TEST_WRONG_IP);
223
224 trace_ao_event_expect(TCP_AO_KEY_NOT_FOUND, this_ip_addr, this_ip_dest,
225 -1, port, 0, 0, 1, 0, 0, 0, 100, 100, -1);
226 try_connect("Non-AO server + AO client", port++, DEFAULT_TEST_PASSWORD,
227 this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT);
228
229 trace_hash_event_expect(TCP_HASH_AO_REQUIRED, this_ip_addr, this_ip_dest,
230 -1, port, 0, 0, 1, 0, 0, 0);
231 try_connect("AO server + Non-AO client", port++, NULL,
232 this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT);
233
234 trace_ao_event_expect(TCP_AO_MISMATCH, this_ip_addr, this_ip_dest,
235 -1, port, 0, 0, 1, 0, 0, 0, 100, 100, -1);
236 try_connect("Wrong password", port++, DEFAULT_TEST_PASSWORD,
237 this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT);
238
239 trace_ao_event_expect(TCP_AO_KEY_NOT_FOUND, this_ip_addr, this_ip_dest,
240 -1, port, 0, 0, 1, 0, 0, 0, 100, 100, -1);
241 try_connect("Wrong rcv id", port++, DEFAULT_TEST_PASSWORD,
242 this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT);
243
244 trace_ao_event_sk_expect(TCP_AO_SYNACK_NO_KEY, this_ip_dest, addr_any,
245 port, 0, 100, 100);
246 try_connect("Wrong snd id", port++, DEFAULT_TEST_PASSWORD,
247 this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT);
248
249 trace_ao_event_expect(TCP_AO_WRONG_MACLEN, this_ip_addr, this_ip_dest,
250 -1, port, 0, 0, 1, 0, 0, 0, 100, 100, -1);
251 try_connect("Different maclen", port++, DEFAULT_TEST_PASSWORD,
252 this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT);
253
254 trace_ao_event_expect(TCP_AO_KEY_NOT_FOUND, this_ip_addr, this_ip_dest,
255 -1, port, 0, 0, 1, 0, 0, 0, 100, 100, -1);
256 try_connect("Server: Wrong addr", port++, DEFAULT_TEST_PASSWORD,
257 this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT);
258
259 try_connect("Client: Wrong addr", port++, DEFAULT_TEST_PASSWORD,
260 wrong_addr, -1, 100, 100, 0, FAULT_KEYREJECT);
261
262 try_connect("rcv id != snd id", port++, DEFAULT_TEST_PASSWORD,
263 this_ip_dest, -1, 100, 200, TEST_CNT_GOOD, 0);
264
265 if (inet_pton(TEST_FAMILY, TEST_NETWORK, &network_addr) != 1)
266 test_error("Can't convert ip address %s", TEST_NETWORK);
267
268 try_connect("Server: prefix match", port++, DEFAULT_TEST_PASSWORD,
269 this_ip_dest, -1, 100, 100, TEST_CNT_GOOD, 0);
270
271 try_connect("Client: prefix match", port++, DEFAULT_TEST_PASSWORD,
272 network_addr, 16, 100, 100, TEST_CNT_GOOD, 0);
273
274 return NULL;
275 }
276
main(int argc,char * argv[])277 int main(int argc, char *argv[])
278 {
279 test_init(22, server_fn, client_fn);
280 return 0;
281 }
282