1  // SPDX-License-Identifier: GPL-2.0
2  #include <fcntl.h>
3  #include <pthread.h>
4  #include <sched.h>
5  #include <signal.h>
6  #include "aolib.h"
7  
8  /*
9   * Can't be included in the header: it defines static variables which
10   * will be unique to every object. Let's include it only once here.
11   */
12  #include "../../../kselftest.h"
13  
/*
 * Taken around every ksft_*() call made by the __test_*() helpers of this
 * file, so that one thread's output is not overridden by another's.
 */
static pthread_mutex_t ksft_print_lock = PTHREAD_MUTEX_INITIALIZER;

/*
 * Print @buf verbatim through ksft_print_msg() while holding
 * ksft_print_lock.
 */
void __test_msg(const char *buf)
{
	pthread_mutex_lock(&ksft_print_lock);
	ksft_print_msg("%s", buf);
	pthread_mutex_unlock(&ksft_print_lock);
}
__test_ok(const char * buf)23  void __test_ok(const char *buf)
24  {
25  	pthread_mutex_lock(&ksft_print_lock);
26  	ksft_test_result_pass("%s", buf);
27  	pthread_mutex_unlock(&ksft_print_lock);
28  }
__test_fail(const char * buf)29  void __test_fail(const char *buf)
30  {
31  	pthread_mutex_lock(&ksft_print_lock);
32  	ksft_test_result_fail("%s", buf);
33  	pthread_mutex_unlock(&ksft_print_lock);
34  }
__test_xfail(const char * buf)35  void __test_xfail(const char *buf)
36  {
37  	pthread_mutex_lock(&ksft_print_lock);
38  	ksft_test_result_xfail("%s", buf);
39  	pthread_mutex_unlock(&ksft_print_lock);
40  }
__test_error(const char * buf)41  void __test_error(const char *buf)
42  {
43  	pthread_mutex_lock(&ksft_print_lock);
44  	ksft_test_result_error("%s", buf);
45  	pthread_mutex_unlock(&ksft_print_lock);
46  }
__test_skip(const char * buf)47  void __test_skip(const char *buf)
48  {
49  	pthread_mutex_lock(&ksft_print_lock);
50  	ksft_test_result_skip("%s", buf);
51  	pthread_mutex_unlock(&ksft_print_lock);
52  }
53  
/*
 * Outcome flags, both initially zero:
 *  - failed:  set by test_failed(), checked by test_exit() and __test_init()
 *  - skipped: set by __test_skip_all(), checked by test_exit()
 */
static volatile int failed;
static volatile int skipped;

/* Mark the run as failed; the flag is never cleared again in this file. */
void test_failed(void)
{
	failed = 1;
}
61  
test_exit(void)62  static void test_exit(void)
63  {
64  	if (failed) {
65  		ksft_exit_fail();
66  	} else if (skipped) {
67  		/* ksft_exit_skip() is different from ksft_exit_*() */
68  		ksft_print_cnts();
69  		exit(KSFT_SKIP);
70  	} else {
71  		ksft_exit_pass();
72  	}
73  }
74  
/* One registered cleanup callback; entries form a singly linked list. */
struct dlist_t {
	void (*destruct)(void);
	struct dlist_t *next;
};
/* Head of the callback list, the most recently registered entry first. */
static struct dlist_t *destructors_list;

/*
 * Register @d as a cleanup callback by pushing it to the head of
 * destructors_list. An allocation failure is reported with test_error().
 */
void test_add_destructor(void (*d)(void))
{
	struct dlist_t *node = malloc(sizeof(*node));

	if (!node)
		test_error("malloc() failed");

	node->destruct = d;
	node->next = destructors_list;
	destructors_list = node;
}
93  
94  static void test_destructor(void) __attribute__((destructor));
test_destructor(void)95  static void test_destructor(void)
96  {
97  	while (destructors_list) {
98  		struct dlist_t *p = destructors_list->next;
99  
100  		destructors_list->destruct();
101  		free(destructors_list);
102  		destructors_list = p;
103  	}
104  	test_exit();
105  }
106  
/*
 * Signal handler installed for SIGINT by __test_init(): reports the
 * interruption through test_error().
 */
static void sig_int(int signo)
{
	(void)signo;	/* the handler does not depend on the signal number */
	test_error("Caught SIGINT - exiting");
}
111  
/*
 * Open the calling thread's network namespace file read-only.
 * A failed open() is reported with test_error().
 *
 * Returns the file descriptor obtained from open().
 */
int open_netns(void)
{
	const char *netns_path = "/proc/thread-self/ns/net";
	int nsfd = open(netns_path, O_RDONLY);

	if (nsfd < 0)
		test_error("open(%s)", netns_path);

	return nsfd;
}
122  
unshare_open_netns(void)123  int unshare_open_netns(void)
124  {
125  	if (unshare(CLONE_NEWNET) != 0)
126  		test_error("unshare()");
127  
128  	return open_netns();
129  }
130  
switch_ns(int fd)131  void switch_ns(int fd)
132  {
133  	if (setns(fd, CLONE_NEWNET))
134  		test_error("setns()");
135  }
136  
/*
 * Enter the network namespace @new_ns, remembering the current one.
 *
 * Returns a descriptor for the namespace that was active on entry; it can
 * be handed to switch_close_ns() to switch back.
 */
int switch_save_ns(int new_ns)
{
	int saved_ns = open_netns();

	switch_ns(new_ns);

	return saved_ns;
}
144  
switch_close_ns(int fd)145  void switch_close_ns(int fd)
146  {
147  	if (setns(fd, CLONE_NEWNET))
148  		test_error("setns()");
149  	close(fd);
150  }
151  
/*
 * Network namespace descriptors, -1 until init_namespaces() has run:
 *  - nsfd_outside: the namespace init_namespaces() was called in
 *  - nsfd_parent:  the namespace created by the first unshare
 *  - nsfd_child:   the namespace created by the second unshare
 */
static int nsfd_outside	= -1;
static int nsfd_parent	= -1;
static int nsfd_child	= -1;
/* Interface name used for the veth created by __test_init() */
const char veth_name[]	= "ktst-veth";

/*
 * Record the current namespace, then unshare twice and record both new
 * namespaces. The caller is left in the last one (nsfd_child).
 */
static void init_namespaces(void)
{
	nsfd_outside = open_netns();
	nsfd_parent = unshare_open_netns();
	nsfd_child = unshare_open_netns();
}
163  
/*
 * Configure interface @veth in the current namespace: set the link up,
 * add @addr with @prefix for @family, and add a route using @addr and
 * @dest. Each step that fails is reported with test_error().
 */
static void link_init(const char *veth, int family, uint8_t prefix,
		      union tcp_addr addr, union tcp_addr dest)
{
	if (link_set_up(veth) != 0)
		test_error("Failed to set link up");

	if (ip_addr_add(veth, family, addr, prefix) != 0)
		test_error("Failed to add ip address");

	if (ip_route_add(veth, family, addr, dest) != 0)
		test_error("Failed to add route");
}
174  
/* Number of threads that have to reach synchronize_threads() per stage */
static unsigned int nr_threads = 1;

/*
 * Barrier state, all of it protected by sync_lock:
 *  - stage_nr:        index (0 or 1) of the stage currently being filled
 *  - stage_threads[]: number of threads that arrived at each stage
 * sync_cond is signalled when a stage becomes complete.
 */
static pthread_mutex_t sync_lock = PTHREAD_MUTEX_INITIALIZER;
static pthread_cond_t sync_cond = PTHREAD_COND_INITIALIZER;
static volatile unsigned int stage_threads[2];
static volatile unsigned int stage_nr;

/*
 * synchronize all threads in the same stage
 *
 * Blocks until nr_threads threads have called this function. The last
 * thread to arrive flips stage_nr, resets the counter of the next stage
 * and wakes up the waiters. Two counters are enough: the counter of a
 * stage is only reset once every thread has arrived at the following
 * stage, i.e. after all of them have left this one.
 */
void synchronize_threads(void)
{
	unsigned int q;

	pthread_mutex_lock(&sync_lock);
	/* stage_nr is shared state: sample it only with sync_lock held */
	q = stage_nr;
	stage_threads[q]++;
	if (stage_threads[q] == nr_threads) {
		stage_nr ^= 1;
		stage_threads[stage_nr] = 0;
		/* every waiter of this stage has to be released, not just one */
		pthread_cond_broadcast(&sync_cond);
	}
	while (stage_threads[q] < nr_threads)
		pthread_cond_wait(&sync_cond, &sync_lock);
	pthread_mutex_unlock(&sync_lock);
}
198  
199  __thread union tcp_addr this_ip_addr;
200  __thread union tcp_addr this_ip_dest;
201  int test_family;
202  
203  struct new_pthread_arg {
204  	thread_fn	func;
205  	union tcp_addr	my_ip;
206  	union tcp_addr	dest_ip;
207  };
new_pthread_entry(void * arg)208  static void *new_pthread_entry(void *arg)
209  {
210  	struct new_pthread_arg *p = arg;
211  
212  	this_ip_addr = p->my_ip;
213  	this_ip_dest = p->dest_ip;
214  	p->func(NULL); /* shouldn't return */
215  	exit(KSFT_FAIL);
216  }
217  
__test_skip_all(const char * msg)218  static void __test_skip_all(const char *msg)
219  {
220  	ksft_set_plan(1);
221  	ksft_print_header();
222  	skipped = 1;
223  	test_skip("%s", msg);
224  	exit(KSFT_SKIP);
225  }
226  
/*
 * Set up the test environment and run the peers. Does not return: the
 * process exits once @peer1 returns.
 *
 * @ntests: number of tests announced in the kselftest plan
 * @family: address family; stored in test_family and used for link setup
 * @prefix: prefix length of the addresses added to the veth interfaces
 * @addr1:  address configured in the parent namespace, used by @peer1
 * @addr2:  address configured in the child namespace, used by @peer2
 * @peer1:  run on the calling thread after switching to the parent
 *          namespace
 * @peer2:  optional (may be NULL); run in a thread that is created while
 *          the caller is switched to the child namespace
 *
 * The whole selftest is skipped when the kernel config lacks NET_NS, VETH
 * or TCP_AO. Any setup failure is reported with test_error().
 */
void __test_init(unsigned int ntests, int family, unsigned int prefix,
		 union tcp_addr addr1, union tcp_addr addr2,
		 thread_fn peer1, thread_fn peer2)
{
	struct sigaction sa = {
		.sa_handler = sig_int,
		.sa_flags = SA_RESTART,
	};
	time_t seed = time(NULL);

	sigemptyset(&sa.sa_mask);
	if (sigaction(SIGINT, &sa, NULL))
		test_error("Can't set SIGINT handler");

	test_family = family;
	if (!kernel_config_has(KCONFIG_NET_NS))
		__test_skip_all(tests_skip_reason[KCONFIG_NET_NS]);
	if (!kernel_config_has(KCONFIG_VETH))
		__test_skip_all(tests_skip_reason[KCONFIG_VETH]);
	if (!kernel_config_has(KCONFIG_TCP_AO))
		__test_skip_all(tests_skip_reason[KCONFIG_TCP_AO]);

	/* The TAP version line has to precede the plan line */
	ksft_print_header();
	ksft_set_plan(ntests);
	test_print("rand seed %u", (unsigned int)seed);
	srand(seed);

	init_namespaces();
	test_init_ftrace(nsfd_parent, nsfd_child);

	if (add_veth(veth_name, nsfd_parent, nsfd_child))
		test_error("Failed to add veth");

	switch_ns(nsfd_child);
	link_init(veth_name, family, prefix, addr2, addr1);
	if (peer2) {
		/*
		 * Static storage: the new thread reads the arguments at some
		 * unknown later point, possibly after this block has been
		 * left, so they must not live in a block-scoped automatic
		 * object.
		 */
		static struct new_pthread_arg targ;
		pthread_t t;

		targ.my_ip = addr2;
		targ.dest_ip = addr1;
		targ.func = peer2;
		/* the peer thread takes part in synchronize_threads() */
		nr_threads++;
		if (pthread_create(&t, NULL, new_pthread_entry, &targ))
			test_error("Failed to create pthread");
	}
	switch_ns(nsfd_parent);
	link_init(veth_name, family, prefix, addr1, addr2);

	this_ip_addr = addr1;
	this_ip_dest = addr2;
	peer1(NULL);
	if (failed)
		exit(KSFT_FAIL);
	else
		exit(KSFT_PASS);
}
284  
/* /proc/sys/net/core/optmem_max artificially limits the amount of memory
 * that can be allocated with sock_kmalloc() on each socket in the system.
 * It is not virtualized in v6.7, so it has to be written outside test
 * namespaces. To be nice a test will revert optmem back to the old value.
 * Keeping it simple without any file lock, which means the tests that
 * need to set/increase optmem value shouldn't run in parallel.
 * Also, not re-entrant.
 * Since commit f5769faeec36 ("net: Namespace-ify sysctl_optmem_max")
 * it is per-namespace, keeping logic for non-virtualized optmem_max
 * for v6.7, which supports TCP-AO.
 */
296  static const char *optmem_file = "/proc/sys/net/core/optmem_max";
297  static size_t saved_optmem;
298  static int optmem_ns = -1;
299  
is_optmem_namespaced(void)300  static bool is_optmem_namespaced(void)
301  {
302  	if (optmem_ns == -1) {
303  		int old_ns = switch_save_ns(nsfd_child);
304  
305  		optmem_ns = !access(optmem_file, F_OK);
306  		switch_close_ns(old_ns);
307  	}
308  	return !!optmem_ns;
309  }
310  
test_get_optmem(void)311  size_t test_get_optmem(void)
312  {
313  	int old_ns = 0;
314  	FILE *foptmem;
315  	size_t ret;
316  
317  	if (!is_optmem_namespaced())
318  		old_ns = switch_save_ns(nsfd_outside);
319  	foptmem = fopen(optmem_file, "r");
320  	if (!foptmem)
321  		test_error("failed to open %s", optmem_file);
322  
323  	if (fscanf(foptmem, "%zu", &ret) != 1)
324  		test_error("can't read from %s", optmem_file);
325  	fclose(foptmem);
326  	if (!is_optmem_namespaced())
327  		switch_close_ns(old_ns);
328  	return ret;
329  }
330  
__test_set_optmem(size_t new,size_t * old)331  static void __test_set_optmem(size_t new, size_t *old)
332  {
333  	int old_ns = 0;
334  	FILE *foptmem;
335  
336  	if (old != NULL)
337  		*old = test_get_optmem();
338  
339  	if (!is_optmem_namespaced())
340  		old_ns = switch_save_ns(nsfd_outside);
341  	foptmem = fopen(optmem_file, "w");
342  	if (!foptmem)
343  		test_error("failed to open %s", optmem_file);
344  
345  	if (fprintf(foptmem, "%zu", new) <= 0)
346  		test_error("can't write %zu to %s", new, optmem_file);
347  	fclose(foptmem);
348  	if (!is_optmem_namespaced())
349  		switch_close_ns(old_ns);
350  }
351  
test_revert_optmem(void)352  static void test_revert_optmem(void)
353  {
354  	if (saved_optmem == 0)
355  		return;
356  
357  	__test_set_optmem(saved_optmem, NULL);
358  }
359  
test_set_optmem(size_t value)360  void test_set_optmem(size_t value)
361  {
362  	if (saved_optmem == 0) {
363  		__test_set_optmem(value, &saved_optmem);
364  		test_add_destructor(test_revert_optmem);
365  	} else {
366  		__test_set_optmem(value, NULL);
367  	}
368  }
369