1 // SPDX-License-Identifier: GPL-2.0
/* This is an over-simplified TCP_REPAIR implementation for TCP_ESTABLISHED
 * sockets. It tests that a TCP-AO enabled connection can be restored.
4  * For the proper socket repair see:
5  * https://github.com/checkpoint-restore/criu/blob/criu-dev/soccr/soccr.h
6  */
7 #include <fcntl.h>
8 #include <linux/sockios.h>
9 #include <sys/ioctl.h>
10 #include "aolib.h"
11 
/*
 * Fallback TCP option kind numbers, for system headers that do not provide
 * them. They are used below as opt_code values for TCP_REPAIR_OPTIONS.
 */
#if !defined(TCPOPT_MAXSEG)
# define TCPOPT_MAXSEG		2
#endif
#if !defined(TCPOPT_WINDOW)
# define TCPOPT_WINDOW		3
#endif
#if !defined(TCPOPT_SACK_PERMITTED)
# define TCPOPT_SACK_PERMITTED	4
#endif
#if !defined(TCPOPT_TIMESTAMP)
# define TCPOPT_TIMESTAMP	8
#endif
24 
/*
 * TCP socket states, compared against tcp_info.tcpi_state (e.g. TCP_CLOSE in
 * __test_sock_checkpoint()). The values are spelled out so the numbering
 * is visible at a glance; they must match the kernel's.
 * NOTE(review): newer kernels may define further states after
 * TCP_NEW_SYN_RECV, so TCP_MAX_STATES here may not match the running
 * kernel -- confirm before relying on it.
 */
enum {
	TCP_ESTABLISHED		= 1,
	TCP_SYN_SENT		= 2,
	TCP_SYN_RECV		= 3,
	TCP_FIN_WAIT1		= 4,
	TCP_FIN_WAIT2		= 5,
	TCP_TIME_WAIT		= 6,
	TCP_CLOSE		= 7,
	TCP_CLOSE_WAIT		= 8,
	TCP_LAST_ACK		= 9,
	TCP_LISTEN		= 10,
	TCP_CLOSING		= 11,	/* Now a valid state */
	TCP_NEW_SYN_RECV	= 12,

	TCP_MAX_STATES		= 13	/* Leave at the end! */
};
41 
test_sock_checkpoint_queue(int sk,int queue,int qlen,struct tcp_sock_queue * q)42 static void test_sock_checkpoint_queue(int sk, int queue, int qlen,
43 				       struct tcp_sock_queue *q)
44 {
45 	socklen_t len;
46 	int ret;
47 
48 	if (setsockopt(sk, SOL_TCP, TCP_REPAIR_QUEUE, &queue, sizeof(queue)))
49 		test_error("setsockopt(TCP_REPAIR_QUEUE)");
50 
51 	len = sizeof(q->seq);
52 	ret = getsockopt(sk, SOL_TCP, TCP_QUEUE_SEQ, &q->seq, &len);
53 	if (ret || len != sizeof(q->seq))
54 		test_error("getsockopt(TCP_QUEUE_SEQ): %d", (int)len);
55 
56 	if (!qlen) {
57 		q->buf = NULL;
58 		return;
59 	}
60 
61 	q->buf = malloc(qlen);
62 	if (q->buf == NULL)
63 		test_error("malloc()");
64 	ret = recv(sk, q->buf, qlen, MSG_PEEK | MSG_DONTWAIT);
65 	if (ret != qlen)
66 		test_error("recv(%d): %d", qlen, ret);
67 }
68 
__test_sock_checkpoint(int sk,struct tcp_sock_state * state,void * addr,size_t addr_size)69 void __test_sock_checkpoint(int sk, struct tcp_sock_state *state,
70 			    void *addr, size_t addr_size)
71 {
72 	socklen_t len = sizeof(state->info);
73 	int ret;
74 
75 	memset(state, 0, sizeof(*state));
76 
77 	ret = getsockopt(sk, SOL_TCP, TCP_INFO, &state->info, &len);
78 	if (ret || len != sizeof(state->info))
79 		test_error("getsockopt(TCP_INFO): %d", (int)len);
80 
81 	len = addr_size;
82 	if (getsockname(sk, addr, &len) || len != addr_size)
83 		test_error("getsockname(): %d", (int)len);
84 
85 	len = sizeof(state->trw);
86 	ret = getsockopt(sk, SOL_TCP, TCP_REPAIR_WINDOW, &state->trw, &len);
87 	if (ret || len != sizeof(state->trw))
88 		test_error("getsockopt(TCP_REPAIR_WINDOW): %d", (int)len);
89 
90 	if (ioctl(sk, SIOCOUTQ, &state->outq_len))
91 		test_error("ioctl(SIOCOUTQ)");
92 
93 	if (ioctl(sk, SIOCOUTQNSD, &state->outq_nsd_len))
94 		test_error("ioctl(SIOCOUTQNSD)");
95 	test_sock_checkpoint_queue(sk, TCP_SEND_QUEUE, state->outq_len, &state->out);
96 
97 	if (ioctl(sk, SIOCINQ, &state->inq_len))
98 		test_error("ioctl(SIOCINQ)");
99 	test_sock_checkpoint_queue(sk, TCP_RECV_QUEUE, state->inq_len, &state->in);
100 
101 	if (state->info.tcpi_state == TCP_CLOSE)
102 		state->outq_len = state->outq_nsd_len = 0;
103 
104 	len = sizeof(state->mss);
105 	ret = getsockopt(sk, SOL_TCP, TCP_MAXSEG, &state->mss, &len);
106 	if (ret || len != sizeof(state->mss))
107 		test_error("getsockopt(TCP_MAXSEG): %d", (int)len);
108 
109 	len = sizeof(state->timestamp);
110 	ret = getsockopt(sk, SOL_TCP, TCP_TIMESTAMP, &state->timestamp, &len);
111 	if (ret || len != sizeof(state->timestamp))
112 		test_error("getsockopt(TCP_TIMESTAMP): %d", (int)len);
113 }
114 
test_ao_checkpoint(int sk,struct tcp_ao_repair * state)115 void test_ao_checkpoint(int sk, struct tcp_ao_repair *state)
116 {
117 	socklen_t len = sizeof(*state);
118 	int ret;
119 
120 	memset(state, 0, sizeof(*state));
121 
122 	ret = getsockopt(sk, SOL_TCP, TCP_AO_REPAIR, state, &len);
123 	if (ret || len != sizeof(*state))
124 		test_error("getsockopt(TCP_AO_REPAIR): %d", (int)len);
125 }
126 
test_sock_restore_seq(int sk,int queue,uint32_t seq)127 static void test_sock_restore_seq(int sk, int queue, uint32_t seq)
128 {
129 	if (setsockopt(sk, SOL_TCP, TCP_REPAIR_QUEUE, &queue, sizeof(queue)))
130 		test_error("setsockopt(TCP_REPAIR_QUEUE)");
131 
132 	if (setsockopt(sk, SOL_TCP, TCP_QUEUE_SEQ, &seq, sizeof(seq)))
133 		test_error("setsockopt(TCP_QUEUE_SEQ)");
134 }
135 
test_sock_restore_queue(int sk,int queue,void * buf,int len)136 static void test_sock_restore_queue(int sk, int queue, void *buf, int len)
137 {
138 	int chunk = len;
139 	size_t off = 0;
140 
141 	if (len == 0)
142 		return;
143 
144 	if (setsockopt(sk, SOL_TCP, TCP_REPAIR_QUEUE, &queue, sizeof(queue)))
145 		test_error("setsockopt(TCP_REPAIR_QUEUE)");
146 
147 	do {
148 		int ret;
149 
150 		ret = send(sk, buf + off, chunk, 0);
151 		if (ret <= 0) {
152 			if (chunk > 1024) {
153 				chunk >>= 1;
154 				continue;
155 			}
156 			test_error("send()");
157 		}
158 		off += ret;
159 		len -= ret;
160 	} while (len > 0);
161 }
162 
__test_sock_restore(int sk,const char * device,struct tcp_sock_state * state,void * saddr,void * daddr,size_t addr_size)163 void __test_sock_restore(int sk, const char *device,
164 			 struct tcp_sock_state *state,
165 			 void *saddr, void *daddr, size_t addr_size)
166 {
167 	struct tcp_repair_opt opts[4];
168 	unsigned int opt_nr = 0;
169 	long flags;
170 
171 	if (bind(sk, saddr, addr_size))
172 		test_error("bind()");
173 
174 	flags = fcntl(sk, F_GETFL);
175 	if ((flags < 0) || (fcntl(sk, F_SETFL, flags | O_NONBLOCK) < 0))
176 		test_error("fcntl()");
177 
178 	test_sock_restore_seq(sk, TCP_RECV_QUEUE, state->in.seq - state->inq_len);
179 	test_sock_restore_seq(sk, TCP_SEND_QUEUE, state->out.seq - state->outq_len);
180 
181 	if (device != NULL && setsockopt(sk, SOL_SOCKET, SO_BINDTODEVICE,
182 					 device, strlen(device) + 1))
183 		test_error("setsockopt(SO_BINDTODEVICE, %s)", device);
184 
185 	if (connect(sk, daddr, addr_size))
186 		test_error("connect()");
187 
188 	if (state->info.tcpi_options & TCPI_OPT_SACK) {
189 		opts[opt_nr].opt_code = TCPOPT_SACK_PERMITTED;
190 		opts[opt_nr].opt_val = 0;
191 		opt_nr++;
192 	}
193 	if (state->info.tcpi_options & TCPI_OPT_WSCALE) {
194 		opts[opt_nr].opt_code = TCPOPT_WINDOW;
195 		opts[opt_nr].opt_val = state->info.tcpi_snd_wscale +
196 				(state->info.tcpi_rcv_wscale << 16);
197 		opt_nr++;
198 	}
199 	if (state->info.tcpi_options & TCPI_OPT_TIMESTAMPS) {
200 		opts[opt_nr].opt_code = TCPOPT_TIMESTAMP;
201 		opts[opt_nr].opt_val = 0;
202 		opt_nr++;
203 	}
204 	opts[opt_nr].opt_code = TCPOPT_MAXSEG;
205 	opts[opt_nr].opt_val = state->mss;
206 	opt_nr++;
207 
208 	if (setsockopt(sk, SOL_TCP, TCP_REPAIR_OPTIONS, opts, opt_nr * sizeof(opts[0])))
209 		test_error("setsockopt(TCP_REPAIR_OPTIONS)");
210 
211 	if (state->info.tcpi_options & TCPI_OPT_TIMESTAMPS) {
212 		if (setsockopt(sk, SOL_TCP, TCP_TIMESTAMP,
213 			       &state->timestamp, opt_nr * sizeof(opts[0])))
214 			test_error("setsockopt(TCP_TIMESTAMP)");
215 	}
216 	test_sock_restore_queue(sk, TCP_RECV_QUEUE, state->in.buf, state->inq_len);
217 	test_sock_restore_queue(sk, TCP_SEND_QUEUE, state->out.buf, state->outq_len);
218 	if (setsockopt(sk, SOL_TCP, TCP_REPAIR_WINDOW, &state->trw, sizeof(state->trw)))
219 		test_error("setsockopt(TCP_REPAIR_WINDOW)");
220 }
221 
test_ao_restore(int sk,struct tcp_ao_repair * state)222 void test_ao_restore(int sk, struct tcp_ao_repair *state)
223 {
224 	if (setsockopt(sk, SOL_TCP, TCP_AO_REPAIR, state, sizeof(*state)))
225 		test_error("setsockopt(TCP_AO_REPAIR)");
226 }
227 
test_sock_state_free(struct tcp_sock_state * state)228 void test_sock_state_free(struct tcp_sock_state *state)
229 {
230 	free(state->out.buf);
231 	free(state->in.buf);
232 }
233 
test_enable_repair(int sk)234 void test_enable_repair(int sk)
235 {
236 	int val = TCP_REPAIR_ON;
237 
238 	if (setsockopt(sk, SOL_TCP, TCP_REPAIR, &val, sizeof(val)))
239 		test_error("setsockopt(TCP_REPAIR)");
240 }
241 
test_disable_repair(int sk)242 void test_disable_repair(int sk)
243 {
244 	int val = TCP_REPAIR_OFF_NO_WP;
245 
246 	if (setsockopt(sk, SOL_TCP, TCP_REPAIR, &val, sizeof(val)))
247 		test_error("setsockopt(TCP_REPAIR)");
248 }
249 
test_kill_sk(int sk)250 void test_kill_sk(int sk)
251 {
252 	test_enable_repair(sk);
253 	close(sk);
254 }
255