1 // SPDX-License-Identifier: GPL-2.0-only
2 /* MSG_ZEROCOPY feature tests for vsock
3  *
4  * Copyright (C) 2023 SberDevices.
5  *
6  * Author: Arseniy Krasnov <avkrasnov@salutedevices.com>
7  */
8 
9 #include <stdio.h>
10 #include <stdlib.h>
11 #include <string.h>
12 #include <sys/mman.h>
13 #include <unistd.h>
14 #include <poll.h>
15 #include <linux/errqueue.h>
16 #include <linux/kernel.h>
17 #include <errno.h>
18 
19 #include "control.h"
20 #include "vsock_test_zerocopy.h"
21 #include "msg_zerocopy_common.h"
22 
#ifndef PAGE_SIZE
#define PAGE_SIZE		4096
#endif

/* Upper bound on the number of iovecs a single test case may use. */
#define VSOCK_TEST_DATA_MAX_IOV 3

/* Description of one MSG_ZEROCOPY test case. */
struct vsock_test_data {
	/* Run this case over SOCK_STREAM sockets only. */
	bool stream_only;
	/* Whether the transmission is expected to stay in zerocopy mode.
	 * It is compared with the bit in 'ee_code' of
	 * 'struct sock_extended_err' that signals a fallback to copy mode.
	 */
	bool zerocopied;
	/* Whether SO_ZEROCOPY is set on the socket. When it is not set,
	 * a MSG_ZEROCOPY send behaves as if the flag had not been passed.
	 */
	bool so_zerocopy;
	/* 'errno' expected after 'sendmsg()'; 0 means the call must succeed. */
	int sendmsg_errno;
	/* Number of entries of 'vecs' that are in use. */
	int vecs_cnt;
	/* Payload layout. The 'iov_base' values (NULL, a small offset,
	 * MAP_FAILED) are not real addresses; they are handed to
	 * alloc_test_iovec(), which presumably treats them as placement
	 * hints - see its definition.
	 */
	struct iovec vecs[VSOCK_TEST_DATA_MAX_IOV];
};

static struct vsock_test_data test_data_array[] = {
	/* Three buffers, the last one is not a multiple of the page size. */
	{
		.stream_only = false,
		.zerocopied = true,
		.so_zerocopy = true,
		.sendmsg_errno = 0,
		.vecs_cnt = 3,
		.vecs = {
			{ .iov_base = NULL, .iov_len = PAGE_SIZE },
			{ .iov_base = NULL, .iov_len = PAGE_SIZE },
			{ .iov_base = NULL, .iov_len = 200 },
		},
	},
	/* Every buffer has a page-aligned base and a page-multiple length. */
	{
		.stream_only = false,
		.zerocopied = true,
		.so_zerocopy = true,
		.sendmsg_errno = 0,
		.vecs_cnt = 3,
		.vecs = {
			{ .iov_base = NULL, .iov_len = PAGE_SIZE },
			{ .iov_base = NULL, .iov_len = PAGE_SIZE * 2 },
			{ .iov_base = NULL, .iov_len = PAGE_SIZE * 3 },
		},
	},
	/* Page-aligned bases and lengths, with a total payload larger
	 * than 64 KiB.
	 */
	{
		.stream_only = false,
		.zerocopied = true,
		.so_zerocopy = true,
		.sendmsg_errno = 0,
		.vecs_cnt = 3,
		.vecs = {
			{ .iov_base = NULL, .iov_len = PAGE_SIZE * 16 },
			{ .iov_base = NULL, .iov_len = PAGE_SIZE * 16 },
			{ .iov_base = NULL, .iov_len = PAGE_SIZE * 16 },
		},
	},
	/* The middle buffer has both a non-page-aligned base and length. */
	{
		.stream_only = false,
		.zerocopied = true,
		.so_zerocopy = true,
		.sendmsg_errno = 0,
		.vecs_cnt = 3,
		.vecs = {
			{ .iov_base = NULL, .iov_len = PAGE_SIZE },
			{ .iov_base = (void *)1, .iov_len = 100 },
			{ .iov_base = NULL, .iov_len = PAGE_SIZE },
		},
	},
	/* The middle buffer is unmapped, so 'sendmsg()' must fail. */
	{
		.stream_only = false,
		.zerocopied = false,
		.so_zerocopy = true,
		.sendmsg_errno = ENOMEM,
		.vecs_cnt = 3,
		.vecs = {
			{ .iov_base = NULL, .iov_len = PAGE_SIZE },
			{ .iov_base = MAP_FAILED, .iov_len = PAGE_SIZE },
			{ .iov_base = NULL, .iov_len = PAGE_SIZE },
		},
	},
	/* Valid payload, but SO_ZEROCOPY is not enabled, so the data is
	 * sent in copy mode.
	 */
	{
		.stream_only = false,
		.zerocopied = false,
		.so_zerocopy = false,
		.sendmsg_errno = 0,
		.vecs_cnt = 1,
		.vecs = {
			{ .iov_base = NULL, .iov_len = PAGE_SIZE },
		},
	},
	/* Valid payload that exceeds the peer's buffer, which forces a
	 * fallback to copy mode. SOCK_STREAM only: for SOCK_SEQPACKET
	 * 'sendmsg()' fails with EMSGSIZE instead.
	 */
	{
		.stream_only = true,
		.zerocopied = false,
		.so_zerocopy = true,
		.sendmsg_errno = 0,
		.vecs_cnt = 1,
		.vecs = {
			{ .iov_base = NULL, .iov_len = 100 * PAGE_SIZE },
		},
	},
};
141 
142 #define POLL_TIMEOUT_MS		100
143 
test_client(const struct test_opts * opts,const struct vsock_test_data * test_data,bool sock_seqpacket)144 static void test_client(const struct test_opts *opts,
145 			const struct vsock_test_data *test_data,
146 			bool sock_seqpacket)
147 {
148 	struct pollfd fds = { 0 };
149 	struct msghdr msg = { 0 };
150 	ssize_t sendmsg_res;
151 	struct iovec *iovec;
152 	int fd;
153 
154 	if (sock_seqpacket)
155 		fd = vsock_seqpacket_connect(opts->peer_cid, opts->peer_port);
156 	else
157 		fd = vsock_stream_connect(opts->peer_cid, opts->peer_port);
158 
159 	if (fd < 0) {
160 		perror("connect");
161 		exit(EXIT_FAILURE);
162 	}
163 
164 	if (test_data->so_zerocopy)
165 		enable_so_zerocopy(fd);
166 
167 	iovec = alloc_test_iovec(test_data->vecs, test_data->vecs_cnt);
168 
169 	msg.msg_iov = iovec;
170 	msg.msg_iovlen = test_data->vecs_cnt;
171 
172 	errno = 0;
173 
174 	sendmsg_res = sendmsg(fd, &msg, MSG_ZEROCOPY);
175 	if (errno != test_data->sendmsg_errno) {
176 		fprintf(stderr, "expected 'errno' == %i, got %i\n",
177 			test_data->sendmsg_errno, errno);
178 		exit(EXIT_FAILURE);
179 	}
180 
181 	if (!errno) {
182 		if (sendmsg_res != iovec_bytes(iovec, test_data->vecs_cnt)) {
183 			fprintf(stderr, "expected 'sendmsg()' == %li, got %li\n",
184 				iovec_bytes(iovec, test_data->vecs_cnt),
185 				sendmsg_res);
186 			exit(EXIT_FAILURE);
187 		}
188 	}
189 
190 	fds.fd = fd;
191 	fds.events = 0;
192 
193 	if (poll(&fds, 1, POLL_TIMEOUT_MS) < 0) {
194 		perror("poll");
195 		exit(EXIT_FAILURE);
196 	}
197 
198 	if (fds.revents & POLLERR) {
199 		vsock_recv_completion(fd, &test_data->zerocopied);
200 	} else if (test_data->so_zerocopy && !test_data->sendmsg_errno) {
201 		/* If we don't have data in the error queue, but
202 		 * SO_ZEROCOPY was enabled and 'sendmsg()' was
203 		 * successful - this is an error.
204 		 */
205 		fprintf(stderr, "POLLERR expected\n");
206 		exit(EXIT_FAILURE);
207 	}
208 
209 	if (!test_data->sendmsg_errno)
210 		control_writeulong(iovec_hash_djb2(iovec, test_data->vecs_cnt));
211 	else
212 		control_writeulong(0);
213 
214 	control_writeln("DONE");
215 	free_test_iovec(test_data->vecs, iovec, test_data->vecs_cnt);
216 	close(fd);
217 }
218 
test_stream_msgzcopy_client(const struct test_opts * opts)219 void test_stream_msgzcopy_client(const struct test_opts *opts)
220 {
221 	int i;
222 
223 	for (i = 0; i < ARRAY_SIZE(test_data_array); i++)
224 		test_client(opts, &test_data_array[i], false);
225 }
226 
test_seqpacket_msgzcopy_client(const struct test_opts * opts)227 void test_seqpacket_msgzcopy_client(const struct test_opts *opts)
228 {
229 	int i;
230 
231 	for (i = 0; i < ARRAY_SIZE(test_data_array); i++) {
232 		if (test_data_array[i].stream_only)
233 			continue;
234 
235 		test_client(opts, &test_data_array[i], true);
236 	}
237 }
238 
test_server(const struct test_opts * opts,const struct vsock_test_data * test_data,bool sock_seqpacket)239 static void test_server(const struct test_opts *opts,
240 			const struct vsock_test_data *test_data,
241 			bool sock_seqpacket)
242 {
243 	unsigned long remote_hash;
244 	unsigned long local_hash;
245 	ssize_t total_bytes_rec;
246 	unsigned char *data;
247 	size_t data_len;
248 	int fd;
249 
250 	if (sock_seqpacket)
251 		fd = vsock_seqpacket_accept(VMADDR_CID_ANY, opts->peer_port, NULL);
252 	else
253 		fd = vsock_stream_accept(VMADDR_CID_ANY, opts->peer_port, NULL);
254 
255 	if (fd < 0) {
256 		perror("accept");
257 		exit(EXIT_FAILURE);
258 	}
259 
260 	data_len = iovec_bytes(test_data->vecs, test_data->vecs_cnt);
261 
262 	data = malloc(data_len);
263 	if (!data) {
264 		perror("malloc");
265 		exit(EXIT_FAILURE);
266 	}
267 
268 	total_bytes_rec = 0;
269 
270 	while (total_bytes_rec != data_len) {
271 		ssize_t bytes_rec;
272 
273 		bytes_rec = read(fd, data + total_bytes_rec,
274 				 data_len - total_bytes_rec);
275 		if (bytes_rec <= 0)
276 			break;
277 
278 		total_bytes_rec += bytes_rec;
279 	}
280 
281 	if (test_data->sendmsg_errno == 0)
282 		local_hash = hash_djb2(data, data_len);
283 	else
284 		local_hash = 0;
285 
286 	free(data);
287 
288 	/* Waiting for some result. */
289 	remote_hash = control_readulong();
290 	if (remote_hash != local_hash) {
291 		fprintf(stderr, "hash mismatch\n");
292 		exit(EXIT_FAILURE);
293 	}
294 
295 	control_expectln("DONE");
296 	close(fd);
297 }
298 
test_stream_msgzcopy_server(const struct test_opts * opts)299 void test_stream_msgzcopy_server(const struct test_opts *opts)
300 {
301 	int i;
302 
303 	for (i = 0; i < ARRAY_SIZE(test_data_array); i++)
304 		test_server(opts, &test_data_array[i], false);
305 }
306 
test_seqpacket_msgzcopy_server(const struct test_opts * opts)307 void test_seqpacket_msgzcopy_server(const struct test_opts *opts)
308 {
309 	int i;
310 
311 	for (i = 0; i < ARRAY_SIZE(test_data_array); i++) {
312 		if (test_data_array[i].stream_only)
313 			continue;
314 
315 		test_server(opts, &test_data_array[i], true);
316 	}
317 }
318 
test_stream_msgzcopy_empty_errq_client(const struct test_opts * opts)319 void test_stream_msgzcopy_empty_errq_client(const struct test_opts *opts)
320 {
321 	struct msghdr msg = { 0 };
322 	char cmsg_data[128];
323 	ssize_t res;
324 	int fd;
325 
326 	fd = vsock_stream_connect(opts->peer_cid, opts->peer_port);
327 	if (fd < 0) {
328 		perror("connect");
329 		exit(EXIT_FAILURE);
330 	}
331 
332 	msg.msg_control = cmsg_data;
333 	msg.msg_controllen = sizeof(cmsg_data);
334 
335 	res = recvmsg(fd, &msg, MSG_ERRQUEUE);
336 	if (res != -1) {
337 		fprintf(stderr, "expected 'recvmsg(2)' failure, got %zi\n",
338 			res);
339 		exit(EXIT_FAILURE);
340 	}
341 
342 	control_writeln("DONE");
343 	close(fd);
344 }
345 
test_stream_msgzcopy_empty_errq_server(const struct test_opts * opts)346 void test_stream_msgzcopy_empty_errq_server(const struct test_opts *opts)
347 {
348 	int fd;
349 
350 	fd = vsock_stream_accept(VMADDR_CID_ANY, opts->peer_port, NULL);
351 	if (fd < 0) {
352 		perror("accept");
353 		exit(EXIT_FAILURE);
354 	}
355 
356 	control_expectln("DONE");
357 	close(fd);
358 }
359