1 // SPDX-License-Identifier: GPL-2.0
2 #define _GNU_SOURCE
3 #define __EXPORTED_HEADERS__
4 
5 #include <linux/uio.h>
6 #include <stdio.h>
7 #include <stdlib.h>
8 #include <unistd.h>
9 #include <stdbool.h>
10 #include <string.h>
11 #include <errno.h>
12 #define __iovec_defined
13 #include <fcntl.h>
14 #include <malloc.h>
15 #include <error.h>
16 
17 #include <arpa/inet.h>
18 #include <sys/socket.h>
19 #include <sys/mman.h>
20 #include <sys/ioctl.h>
21 #include <sys/syscall.h>
22 
23 #include <linux/memfd.h>
24 #include <linux/dma-buf.h>
25 #include <linux/udmabuf.h>
26 #include <libmnl/libmnl.h>
27 #include <linux/types.h>
28 #include <linux/netlink.h>
29 #include <linux/genetlink.h>
30 #include <linux/netdev.h>
31 #include <time.h>
32 #include <net/if.h>
33 
34 #include "netdev-user.h"
35 #include <ynl.h>
36 
/* log2(4096); only used for the frag_page debug print in do_server(). */
#define PAGE_SHIFT 12
/* Prefix for the skip/fail/ok status lines this test prints. */
#define TEST_PREFIX "ncdevmem"
/* Size of each udmabuf, in getpagesize() pages. */
#define NUM_PAGES 16000

/*
 * recvmsg() flag asking for devmem cmsgs; defined locally in case the libc
 * headers do not provide it. NOTE(review): the value must match the kernel
 * uapi definition — confirm when updating headers.
 */
#ifndef MSG_SOCK_DEVMEM
#define MSG_SOCK_DEVMEM 0x2000000
#endif

/*
 * tcpdevmem netcat. Works similarly to netcat but does device memory TCP
 * instead of regular TCP. Uses udmabuf to mock a dmabuf provider.
 *
 * Usage:
 *
 *	On server:
 *	ncdevmem -s <server IP> -c <client IP> -f eth1 -l -p 5201 -v 7
 *
 *	On client:
 *	yes $(echo -e \\x01\\x02\\x03\\x04\\x05\\x06) | \
 *		tr \\n \\0 | \
 *		head -c 5G | \
 *		nc <server IP> 5201 -p 5201
 *
 * The client streams the repeating byte pattern 01 02 03 04 05 06 00 (the
 * newline yes(1) emits after each line is turned into a NUL by tr), and
 * "-v 7" makes the server validate exactly that pattern. The client binds
 * source port 5201 because the flow-steering rule uses the same port as both
 * source and destination port.
 *
 * Note this is compatible with regular netcat, i.e. the sender or receiver can
 * be replaced with regular netcat to test the RX or TX path in isolation.
 */

/* Addresses matched by the flow-steering rule; server_ip is also bound (-s, -c). */
static char *server_ip = "192.168.1.4";
static char *client_ip = "192.168.1.2";
/* TCP port used as both src-port and dst-port of the steering rule (-p). */
static char *port = "5201";
/*
 * -v: when non-zero, received data is validated against the repeating
 * pattern 1 .. do_validation - 1, 0; when zero, the data is printed instead.
 */
static size_t do_validation;
/* First devmem RX queue (-t) and number of consecutive queues bound (-q). */
static int start_queue = 8;
static int num_queues = 8;
/* Interface under test (-f); ifindex is resolved from it in main(). */
static char *ifname = "eth1";
static unsigned int ifindex;
/* Binding id returned by the bind request; every devmem cmsg must carry it. */
static unsigned int dmabuf_id;
73 
print_bytes(void * ptr,size_t size)74 void print_bytes(void *ptr, size_t size)
75 {
76 	unsigned char *p = ptr;
77 	int i;
78 
79 	for (i = 0; i < size; i++)
80 		printf("%02hhX ", p[i]);
81 	printf("\n");
82 }
83 
print_nonzero_bytes(void * ptr,size_t size)84 void print_nonzero_bytes(void *ptr, size_t size)
85 {
86 	unsigned char *p = ptr;
87 	unsigned int i;
88 
89 	for (i = 0; i < size; i++)
90 		putchar(p[i]);
91 	printf("\n");
92 }
93 
validate_buffer(void * line,size_t size)94 void validate_buffer(void *line, size_t size)
95 {
96 	static unsigned char seed = 1;
97 	unsigned char *ptr = line;
98 	int errors = 0;
99 	size_t i;
100 
101 	for (i = 0; i < size; i++) {
102 		if (ptr[i] != seed) {
103 			fprintf(stderr,
104 				"Failed validation: expected=%u, actual=%u, index=%lu\n",
105 				seed, ptr[i], i);
106 			errors++;
107 			if (errors > 20)
108 				error(1, 0, "validation failed.");
109 		}
110 		seed++;
111 		if (seed == do_validation)
112 			seed = 0;
113 	}
114 
115 	fprintf(stdout, "Validated buffer\n");
116 }
117 
/*
 * run_command() - format a shell command, echo it and run it with system().
 *
 * Evaluates to the system() status (0 on success). If the formatted command
 * does not fit in the 256-byte buffer, nothing is run and the expression
 * evaluates to -1: running a silently truncated ethtool command line could
 * apply a different configuration than the one requested.
 */
#define run_command(cmd, ...)                                                \
	({                                                                   \
		char command[256];                                           \
		int cmd_len;                                                 \
		int cmd_status = -1;                                         \
									     \
		memset(command, 0, sizeof(command));                         \
		cmd_len = snprintf(command, sizeof(command), cmd,            \
				   ##__VA_ARGS__);                           \
		if (cmd_len < 0 || (size_t)cmd_len >= sizeof(command)) {     \
			fprintf(stderr,                                      \
				"Command too long, not running: %s\n",       \
				command);                                    \
		} else {                                                     \
			printf("Running: %s\n", command);                    \
			cmd_status = system(command);                        \
		}                                                            \
		cmd_status;                                                  \
	})
126 
reset_flow_steering(void)127 static int reset_flow_steering(void)
128 {
129 	int ret = 0;
130 
131 	ret = run_command("sudo ethtool -K %s ntuple off", ifname);
132 	if (ret)
133 		return ret;
134 
135 	return run_command("sudo ethtool -K %s ntuple on", ifname);
136 }
137 
configure_headersplit(bool on)138 static int configure_headersplit(bool on)
139 {
140 	return run_command("sudo ethtool -G %s tcp-data-split %s", ifname,
141 			   on ? "on" : "off");
142 }
143 
configure_rss(void)144 static int configure_rss(void)
145 {
146 	return run_command("sudo ethtool -X %s equal %d", ifname, start_queue);
147 }
148 
configure_channels(unsigned int rx,unsigned int tx)149 static int configure_channels(unsigned int rx, unsigned int tx)
150 {
151 	return run_command("sudo ethtool -L %s rx %u tx %u", ifname, rx, tx);
152 }
153 
configure_flow_steering(void)154 static int configure_flow_steering(void)
155 {
156 	return run_command("sudo ethtool -N %s flow-type tcp4 src-ip %s dst-ip %s src-port %s dst-port %s queue %d",
157 			   ifname, client_ip, server_ip, port, port, start_queue);
158 }
159 
bind_rx_queue(unsigned int ifindex,unsigned int dmabuf_fd,struct netdev_queue_id * queues,unsigned int n_queue_index,struct ynl_sock ** ys)160 static int bind_rx_queue(unsigned int ifindex, unsigned int dmabuf_fd,
161 			 struct netdev_queue_id *queues,
162 			 unsigned int n_queue_index, struct ynl_sock **ys)
163 {
164 	struct netdev_bind_rx_req *req = NULL;
165 	struct netdev_bind_rx_rsp *rsp = NULL;
166 	struct ynl_error yerr;
167 
168 	*ys = ynl_sock_create(&ynl_netdev_family, &yerr);
169 	if (!*ys) {
170 		fprintf(stderr, "YNL: %s\n", yerr.msg);
171 		return -1;
172 	}
173 
174 	req = netdev_bind_rx_req_alloc();
175 	netdev_bind_rx_req_set_ifindex(req, ifindex);
176 	netdev_bind_rx_req_set_fd(req, dmabuf_fd);
177 	__netdev_bind_rx_req_set_queues(req, queues, n_queue_index);
178 
179 	rsp = netdev_bind_rx(*ys, req);
180 	if (!rsp) {
181 		perror("netdev_bind_rx");
182 		goto err_close;
183 	}
184 
185 	if (!rsp->_present.id) {
186 		perror("id not present");
187 		goto err_close;
188 	}
189 
190 	printf("got dmabuf id=%d\n", rsp->id);
191 	dmabuf_id = rsp->id;
192 
193 	netdev_bind_rx_req_free(req);
194 	netdev_bind_rx_rsp_free(rsp);
195 
196 	return 0;
197 
198 err_close:
199 	fprintf(stderr, "YNL failed: %s\n", (*ys)->err.msg);
200 	netdev_bind_rx_req_free(req);
201 	ynl_sock_destroy(*ys);
202 	return -1;
203 }
204 
create_udmabuf(int * devfd,int * memfd,int * buf,size_t dmabuf_size)205 static void create_udmabuf(int *devfd, int *memfd, int *buf, size_t dmabuf_size)
206 {
207 	struct udmabuf_create create;
208 	int ret;
209 
210 	*devfd = open("/dev/udmabuf", O_RDWR);
211 	if (*devfd < 0) {
212 		error(70, 0,
213 		      "%s: [skip,no-udmabuf: Unable to access DMA buffer device file]\n",
214 		      TEST_PREFIX);
215 	}
216 
217 	*memfd = memfd_create("udmabuf-test", MFD_ALLOW_SEALING);
218 	if (*memfd < 0)
219 		error(70, 0, "%s: [skip,no-memfd]\n", TEST_PREFIX);
220 
221 	/* Required for udmabuf */
222 	ret = fcntl(*memfd, F_ADD_SEALS, F_SEAL_SHRINK);
223 	if (ret < 0)
224 		error(73, 0, "%s: [skip,fcntl-add-seals]\n", TEST_PREFIX);
225 
226 	ret = ftruncate(*memfd, dmabuf_size);
227 	if (ret == -1)
228 		error(74, 0, "%s: [FAIL,memfd-truncate]\n", TEST_PREFIX);
229 
230 	memset(&create, 0, sizeof(create));
231 
232 	create.memfd = *memfd;
233 	create.offset = 0;
234 	create.size = dmabuf_size;
235 	*buf = ioctl(*devfd, UDMABUF_CREATE, &create);
236 	if (*buf < 0)
237 		error(75, 0, "%s: [FAIL, create udmabuf]\n", TEST_PREFIX);
238 }
239 
do_server(void)240 int do_server(void)
241 {
242 	char ctrl_data[sizeof(int) * 20000];
243 	struct netdev_queue_id *queues;
244 	size_t non_page_aligned_frags = 0;
245 	struct sockaddr_in client_addr;
246 	struct sockaddr_in server_sin;
247 	size_t page_aligned_frags = 0;
248 	int devfd, memfd, buf, ret;
249 	size_t total_received = 0;
250 	socklen_t client_addr_len;
251 	bool is_devmem = false;
252 	char *buf_mem = NULL;
253 	struct ynl_sock *ys;
254 	size_t dmabuf_size;
255 	char iobuf[819200];
256 	char buffer[256];
257 	int socket_fd;
258 	int client_fd;
259 	size_t i = 0;
260 	int opt = 1;
261 
262 	dmabuf_size = getpagesize() * NUM_PAGES;
263 
264 	create_udmabuf(&devfd, &memfd, &buf, dmabuf_size);
265 
266 	if (reset_flow_steering())
267 		error(1, 0, "Failed to reset flow steering\n");
268 
269 	/* Configure RSS to divert all traffic from our devmem queues */
270 	if (configure_rss())
271 		error(1, 0, "Failed to configure rss\n");
272 
273 	/* Flow steer our devmem flows to start_queue */
274 	if (configure_flow_steering())
275 		error(1, 0, "Failed to configure flow steering\n");
276 
277 	sleep(1);
278 
279 	queues = malloc(sizeof(*queues) * num_queues);
280 
281 	for (i = 0; i < num_queues; i++) {
282 		queues[i]._present.type = 1;
283 		queues[i]._present.id = 1;
284 		queues[i].type = NETDEV_QUEUE_TYPE_RX;
285 		queues[i].id = start_queue + i;
286 	}
287 
288 	if (bind_rx_queue(ifindex, buf, queues, num_queues, &ys))
289 		error(1, 0, "Failed to bind\n");
290 
291 	buf_mem = mmap(NULL, dmabuf_size, PROT_READ | PROT_WRITE, MAP_SHARED,
292 		       buf, 0);
293 	if (buf_mem == MAP_FAILED)
294 		error(1, 0, "mmap()");
295 
296 	server_sin.sin_family = AF_INET;
297 	server_sin.sin_port = htons(atoi(port));
298 
299 	ret = inet_pton(server_sin.sin_family, server_ip, &server_sin.sin_addr);
300 	if (socket < 0)
301 		error(79, 0, "%s: [FAIL, create socket]\n", TEST_PREFIX);
302 
303 	socket_fd = socket(server_sin.sin_family, SOCK_STREAM, 0);
304 	if (socket < 0)
305 		error(errno, errno, "%s: [FAIL, create socket]\n", TEST_PREFIX);
306 
307 	ret = setsockopt(socket_fd, SOL_SOCKET, SO_REUSEPORT, &opt,
308 			 sizeof(opt));
309 	if (ret)
310 		error(errno, errno, "%s: [FAIL, set sock opt]\n", TEST_PREFIX);
311 
312 	ret = setsockopt(socket_fd, SOL_SOCKET, SO_REUSEADDR, &opt,
313 			 sizeof(opt));
314 	if (ret)
315 		error(errno, errno, "%s: [FAIL, set sock opt]\n", TEST_PREFIX);
316 
317 	printf("binding to address %s:%d\n", server_ip,
318 	       ntohs(server_sin.sin_port));
319 
320 	ret = bind(socket_fd, &server_sin, sizeof(server_sin));
321 	if (ret)
322 		error(errno, errno, "%s: [FAIL, bind]\n", TEST_PREFIX);
323 
324 	ret = listen(socket_fd, 1);
325 	if (ret)
326 		error(errno, errno, "%s: [FAIL, listen]\n", TEST_PREFIX);
327 
328 	client_addr_len = sizeof(client_addr);
329 
330 	inet_ntop(server_sin.sin_family, &server_sin.sin_addr, buffer,
331 		  sizeof(buffer));
332 	printf("Waiting or connection on %s:%d\n", buffer,
333 	       ntohs(server_sin.sin_port));
334 	client_fd = accept(socket_fd, &client_addr, &client_addr_len);
335 
336 	inet_ntop(client_addr.sin_family, &client_addr.sin_addr, buffer,
337 		  sizeof(buffer));
338 	printf("Got connection from %s:%d\n", buffer,
339 	       ntohs(client_addr.sin_port));
340 
341 	while (1) {
342 		struct iovec iov = { .iov_base = iobuf,
343 				     .iov_len = sizeof(iobuf) };
344 		struct dmabuf_cmsg *dmabuf_cmsg = NULL;
345 		struct dma_buf_sync sync = { 0 };
346 		struct cmsghdr *cm = NULL;
347 		struct msghdr msg = { 0 };
348 		struct dmabuf_token token;
349 		ssize_t ret;
350 
351 		is_devmem = false;
352 		printf("\n\n");
353 
354 		msg.msg_iov = &iov;
355 		msg.msg_iovlen = 1;
356 		msg.msg_control = ctrl_data;
357 		msg.msg_controllen = sizeof(ctrl_data);
358 		ret = recvmsg(client_fd, &msg, MSG_SOCK_DEVMEM);
359 		printf("recvmsg ret=%ld\n", ret);
360 		if (ret < 0 && (errno == EAGAIN || errno == EWOULDBLOCK))
361 			continue;
362 		if (ret < 0) {
363 			perror("recvmsg");
364 			continue;
365 		}
366 		if (ret == 0) {
367 			printf("client exited\n");
368 			goto cleanup;
369 		}
370 
371 		i++;
372 		for (cm = CMSG_FIRSTHDR(&msg); cm; cm = CMSG_NXTHDR(&msg, cm)) {
373 			if (cm->cmsg_level != SOL_SOCKET ||
374 			    (cm->cmsg_type != SCM_DEVMEM_DMABUF &&
375 			     cm->cmsg_type != SCM_DEVMEM_LINEAR)) {
376 				fprintf(stdout, "skipping non-devmem cmsg\n");
377 				continue;
378 			}
379 
380 			dmabuf_cmsg = (struct dmabuf_cmsg *)CMSG_DATA(cm);
381 			is_devmem = true;
382 
383 			if (cm->cmsg_type == SCM_DEVMEM_LINEAR) {
384 				/* TODO: process data copied from skb's linear
385 				 * buffer.
386 				 */
387 				fprintf(stdout,
388 					"SCM_DEVMEM_LINEAR. dmabuf_cmsg->frag_size=%u\n",
389 					dmabuf_cmsg->frag_size);
390 
391 				continue;
392 			}
393 
394 			token.token_start = dmabuf_cmsg->frag_token;
395 			token.token_count = 1;
396 
397 			total_received += dmabuf_cmsg->frag_size;
398 			printf("received frag_page=%llu, in_page_offset=%llu, frag_offset=%llu, frag_size=%u, token=%u, total_received=%lu, dmabuf_id=%u\n",
399 			       dmabuf_cmsg->frag_offset >> PAGE_SHIFT,
400 			       dmabuf_cmsg->frag_offset % getpagesize(),
401 			       dmabuf_cmsg->frag_offset, dmabuf_cmsg->frag_size,
402 			       dmabuf_cmsg->frag_token, total_received,
403 			       dmabuf_cmsg->dmabuf_id);
404 
405 			if (dmabuf_cmsg->dmabuf_id != dmabuf_id)
406 				error(1, 0,
407 				      "received on wrong dmabuf_id: flow steering error\n");
408 
409 			if (dmabuf_cmsg->frag_size % getpagesize())
410 				non_page_aligned_frags++;
411 			else
412 				page_aligned_frags++;
413 
414 			sync.flags = DMA_BUF_SYNC_READ | DMA_BUF_SYNC_START;
415 			ioctl(buf, DMA_BUF_IOCTL_SYNC, &sync);
416 
417 			if (do_validation)
418 				validate_buffer(
419 					((unsigned char *)buf_mem) +
420 						dmabuf_cmsg->frag_offset,
421 					dmabuf_cmsg->frag_size);
422 			else
423 				print_nonzero_bytes(
424 					((unsigned char *)buf_mem) +
425 						dmabuf_cmsg->frag_offset,
426 					dmabuf_cmsg->frag_size);
427 
428 			sync.flags = DMA_BUF_SYNC_READ | DMA_BUF_SYNC_END;
429 			ioctl(buf, DMA_BUF_IOCTL_SYNC, &sync);
430 
431 			ret = setsockopt(client_fd, SOL_SOCKET,
432 					 SO_DEVMEM_DONTNEED, &token,
433 					 sizeof(token));
434 			if (ret != 1)
435 				error(1, 0,
436 				      "SO_DEVMEM_DONTNEED not enough tokens");
437 		}
438 		if (!is_devmem)
439 			error(1, 0, "flow steering error\n");
440 
441 		printf("total_received=%lu\n", total_received);
442 	}
443 
444 	fprintf(stdout, "%s: ok\n", TEST_PREFIX);
445 
446 	fprintf(stdout, "page_aligned_frags=%lu, non_page_aligned_frags=%lu\n",
447 		page_aligned_frags, non_page_aligned_frags);
448 
449 	fprintf(stdout, "page_aligned_frags=%lu, non_page_aligned_frags=%lu\n",
450 		page_aligned_frags, non_page_aligned_frags);
451 
452 cleanup:
453 
454 	munmap(buf_mem, dmabuf_size);
455 	close(client_fd);
456 	close(socket_fd);
457 	close(buf);
458 	close(memfd);
459 	close(devfd);
460 	ynl_sock_destroy(ys);
461 
462 	return 0;
463 }
464 
run_devmem_tests(void)465 void run_devmem_tests(void)
466 {
467 	struct netdev_queue_id *queues;
468 	int devfd, memfd, buf;
469 	struct ynl_sock *ys;
470 	size_t dmabuf_size;
471 	size_t i = 0;
472 
473 	dmabuf_size = getpagesize() * NUM_PAGES;
474 
475 	create_udmabuf(&devfd, &memfd, &buf, dmabuf_size);
476 
477 	/* Configure RSS to divert all traffic from our devmem queues */
478 	if (configure_rss())
479 		error(1, 0, "rss error\n");
480 
481 	queues = calloc(num_queues, sizeof(*queues));
482 
483 	if (configure_headersplit(1))
484 		error(1, 0, "Failed to configure header split\n");
485 
486 	if (!bind_rx_queue(ifindex, buf, queues, num_queues, &ys))
487 		error(1, 0, "Binding empty queues array should have failed\n");
488 
489 	for (i = 0; i < num_queues; i++) {
490 		queues[i]._present.type = 1;
491 		queues[i]._present.id = 1;
492 		queues[i].type = NETDEV_QUEUE_TYPE_RX;
493 		queues[i].id = start_queue + i;
494 	}
495 
496 	if (configure_headersplit(0))
497 		error(1, 0, "Failed to configure header split\n");
498 
499 	if (!bind_rx_queue(ifindex, buf, queues, num_queues, &ys))
500 		error(1, 0, "Configure dmabuf with header split off should have failed\n");
501 
502 	if (configure_headersplit(1))
503 		error(1, 0, "Failed to configure header split\n");
504 
505 	for (i = 0; i < num_queues; i++) {
506 		queues[i]._present.type = 1;
507 		queues[i]._present.id = 1;
508 		queues[i].type = NETDEV_QUEUE_TYPE_RX;
509 		queues[i].id = start_queue + i;
510 	}
511 
512 	if (bind_rx_queue(ifindex, buf, queues, num_queues, &ys))
513 		error(1, 0, "Failed to bind\n");
514 
515 	/* Deactivating a bound queue should not be legal */
516 	if (!configure_channels(num_queues, num_queues - 1))
517 		error(1, 0, "Deactivating a bound queue should be illegal.\n");
518 
519 	/* Closing the netlink socket does an implicit unbind */
520 	ynl_sock_destroy(ys);
521 }
522 
main(int argc,char * argv[])523 int main(int argc, char *argv[])
524 {
525 	int is_server = 0, opt;
526 
527 	while ((opt = getopt(argc, argv, "ls:c:p:v:q:t:f:")) != -1) {
528 		switch (opt) {
529 		case 'l':
530 			is_server = 1;
531 			break;
532 		case 's':
533 			server_ip = optarg;
534 			break;
535 		case 'c':
536 			client_ip = optarg;
537 			break;
538 		case 'p':
539 			port = optarg;
540 			break;
541 		case 'v':
542 			do_validation = atoll(optarg);
543 			break;
544 		case 'q':
545 			num_queues = atoi(optarg);
546 			break;
547 		case 't':
548 			start_queue = atoi(optarg);
549 			break;
550 		case 'f':
551 			ifname = optarg;
552 			break;
553 		case '?':
554 			printf("unknown option: %c\n", optopt);
555 			break;
556 		}
557 	}
558 
559 	ifindex = if_nametoindex(ifname);
560 
561 	for (; optind < argc; optind++)
562 		printf("extra arguments: %s\n", argv[optind]);
563 
564 	run_devmem_tests();
565 
566 	if (is_server)
567 		return do_server();
568 
569 	return 0;
570 }
571