1  // SPDX-License-Identifier: GPL-2.0
2  
3  #include <kunit/test.h>
4  
5  #include "utils.h"
6  
/*
 * A route under test. Its ->output hook (mctp_test_route_output) queues
 * outgoing packets on @pkts instead of transmitting them, so tests can
 * inspect what the routing code emitted.
 */
struct mctp_test_route {
	struct mctp_route	rt;	/* must stay first: destroy frees via &rt->rt */
	struct sk_buff_head	pkts;	/* packets captured by ->output */
};
11  
mctp_test_route_output(struct mctp_route * rt,struct sk_buff * skb)12  static int mctp_test_route_output(struct mctp_route *rt, struct sk_buff *skb)
13  {
14  	struct mctp_test_route *test_rt = container_of(rt, struct mctp_test_route, rt);
15  
16  	skb_queue_tail(&test_rt->pkts, skb);
17  
18  	return 0;
19  }
20  
21  /* local version of mctp_route_alloc() */
mctp_route_test_alloc(void)22  static struct mctp_test_route *mctp_route_test_alloc(void)
23  {
24  	struct mctp_test_route *rt;
25  
26  	rt = kzalloc(sizeof(*rt), GFP_KERNEL);
27  	if (!rt)
28  		return NULL;
29  
30  	INIT_LIST_HEAD(&rt->rt.list);
31  	refcount_set(&rt->rt.refs, 1);
32  	rt->rt.output = mctp_test_route_output;
33  
34  	skb_queue_head_init(&rt->pkts);
35  
36  	return rt;
37  }
38  
mctp_test_create_route(struct net * net,struct mctp_dev * dev,mctp_eid_t eid,unsigned int mtu)39  static struct mctp_test_route *mctp_test_create_route(struct net *net,
40  						      struct mctp_dev *dev,
41  						      mctp_eid_t eid,
42  						      unsigned int mtu)
43  {
44  	struct mctp_test_route *rt;
45  
46  	rt = mctp_route_test_alloc();
47  	if (!rt)
48  		return NULL;
49  
50  	rt->rt.min = eid;
51  	rt->rt.max = eid;
52  	rt->rt.mtu = mtu;
53  	rt->rt.type = RTN_UNSPEC;
54  	if (dev)
55  		mctp_dev_hold(dev);
56  	rt->rt.dev = dev;
57  
58  	list_add_rcu(&rt->rt.list, &net->mctp.routes);
59  
60  	return rt;
61  }
62  
mctp_test_route_destroy(struct kunit * test,struct mctp_test_route * rt)63  static void mctp_test_route_destroy(struct kunit *test,
64  				    struct mctp_test_route *rt)
65  {
66  	unsigned int refs;
67  
68  	rtnl_lock();
69  	list_del_rcu(&rt->rt.list);
70  	rtnl_unlock();
71  
72  	skb_queue_purge(&rt->pkts);
73  	if (rt->rt.dev)
74  		mctp_dev_put(rt->rt.dev);
75  
76  	refs = refcount_read(&rt->rt.refs);
77  	KUNIT_ASSERT_EQ_MSG(test, refs, 1, "route ref imbalance");
78  
79  	kfree_rcu(&rt->rt, rcu);
80  }
81  
mctp_test_skb_set_dev(struct sk_buff * skb,struct mctp_test_dev * dev)82  static void mctp_test_skb_set_dev(struct sk_buff *skb,
83  				  struct mctp_test_dev *dev)
84  {
85  	struct mctp_skb_cb *cb;
86  
87  	cb = mctp_cb(skb);
88  	cb->net = READ_ONCE(dev->mdev->net);
89  	skb->dev = dev->ndev;
90  }
91  
mctp_test_create_skb(const struct mctp_hdr * hdr,unsigned int data_len)92  static struct sk_buff *mctp_test_create_skb(const struct mctp_hdr *hdr,
93  					    unsigned int data_len)
94  {
95  	size_t hdr_len = sizeof(*hdr);
96  	struct sk_buff *skb;
97  	unsigned int i;
98  	u8 *buf;
99  
100  	skb = alloc_skb(hdr_len + data_len, GFP_KERNEL);
101  	if (!skb)
102  		return NULL;
103  
104  	__mctp_cb(skb);
105  	memcpy(skb_put(skb, hdr_len), hdr, hdr_len);
106  
107  	buf = skb_put(skb, data_len);
108  	for (i = 0; i < data_len; i++)
109  		buf[i] = i & 0xff;
110  
111  	return skb;
112  }
113  
__mctp_test_create_skb_data(const struct mctp_hdr * hdr,const void * data,size_t data_len)114  static struct sk_buff *__mctp_test_create_skb_data(const struct mctp_hdr *hdr,
115  						   const void *data,
116  						   size_t data_len)
117  {
118  	size_t hdr_len = sizeof(*hdr);
119  	struct sk_buff *skb;
120  
121  	skb = alloc_skb(hdr_len + data_len, GFP_KERNEL);
122  	if (!skb)
123  		return NULL;
124  
125  	__mctp_cb(skb);
126  	memcpy(skb_put(skb, hdr_len), hdr, hdr_len);
127  	memcpy(skb_put(skb, data_len), data, data_len);
128  
129  	return skb;
130  }
131  
132  #define mctp_test_create_skb_data(h, d) \
133  	__mctp_test_create_skb_data(h, d, sizeof(*d))
134  
/* parameters for the mctp_test_fragment test */
struct mctp_frag_test {
	unsigned int mtu;	/* MTU of the route and of the fragmenter */
	unsigned int msgsize;	/* payload bytes following the MCTP header */
	unsigned int n_frags;	/* expected number of fragments */
};
140  
mctp_test_fragment(struct kunit * test)141  static void mctp_test_fragment(struct kunit *test)
142  {
143  	const struct mctp_frag_test *params;
144  	int rc, i, n, mtu, msgsize;
145  	struct mctp_test_route *rt;
146  	struct sk_buff *skb;
147  	struct mctp_hdr hdr;
148  	u8 seq;
149  
150  	params = test->param_value;
151  	mtu = params->mtu;
152  	msgsize = params->msgsize;
153  
154  	hdr.ver = 1;
155  	hdr.src = 8;
156  	hdr.dest = 10;
157  	hdr.flags_seq_tag = MCTP_HDR_FLAG_TO;
158  
159  	skb = mctp_test_create_skb(&hdr, msgsize);
160  	KUNIT_ASSERT_TRUE(test, skb);
161  
162  	rt = mctp_test_create_route(&init_net, NULL, 10, mtu);
163  	KUNIT_ASSERT_TRUE(test, rt);
164  
165  	rc = mctp_do_fragment_route(&rt->rt, skb, mtu, MCTP_TAG_OWNER);
166  	KUNIT_EXPECT_FALSE(test, rc);
167  
168  	n = rt->pkts.qlen;
169  
170  	KUNIT_EXPECT_EQ(test, n, params->n_frags);
171  
172  	for (i = 0;; i++) {
173  		struct mctp_hdr *hdr2;
174  		struct sk_buff *skb2;
175  		u8 tag_mask, seq2;
176  		bool first, last;
177  
178  		first = i == 0;
179  		last = i == (n - 1);
180  
181  		skb2 = skb_dequeue(&rt->pkts);
182  
183  		if (!skb2)
184  			break;
185  
186  		hdr2 = mctp_hdr(skb2);
187  
188  		tag_mask = MCTP_HDR_TAG_MASK | MCTP_HDR_FLAG_TO;
189  
190  		KUNIT_EXPECT_EQ(test, hdr2->ver, hdr.ver);
191  		KUNIT_EXPECT_EQ(test, hdr2->src, hdr.src);
192  		KUNIT_EXPECT_EQ(test, hdr2->dest, hdr.dest);
193  		KUNIT_EXPECT_EQ(test, hdr2->flags_seq_tag & tag_mask,
194  				hdr.flags_seq_tag & tag_mask);
195  
196  		KUNIT_EXPECT_EQ(test,
197  				!!(hdr2->flags_seq_tag & MCTP_HDR_FLAG_SOM), first);
198  		KUNIT_EXPECT_EQ(test,
199  				!!(hdr2->flags_seq_tag & MCTP_HDR_FLAG_EOM), last);
200  
201  		seq2 = (hdr2->flags_seq_tag >> MCTP_HDR_SEQ_SHIFT) &
202  			MCTP_HDR_SEQ_MASK;
203  
204  		if (first) {
205  			seq = seq2;
206  		} else {
207  			seq++;
208  			KUNIT_EXPECT_EQ(test, seq2, seq & MCTP_HDR_SEQ_MASK);
209  		}
210  
211  		if (!last)
212  			KUNIT_EXPECT_EQ(test, skb2->len, mtu);
213  		else
214  			KUNIT_EXPECT_LE(test, skb2->len, mtu);
215  
216  		kfree_skb(skb2);
217  	}
218  
219  	mctp_test_route_destroy(test, rt);
220  }
221  
/*
 * Boundary cases around a 68-byte MTU. The expected counts imply 64 bytes of
 * payload per fragment: 64 bytes fit in one, 65 need two, 128 fit in two and
 * 129 need three.
 */
static const struct mctp_frag_test mctp_frag_tests[] = {
	{.mtu = 68, .msgsize = 63, .n_frags = 1},
	{.mtu = 68, .msgsize = 64, .n_frags = 1},
	{.mtu = 68, .msgsize = 65, .n_frags = 2},
	{.mtu = 68, .msgsize = 66, .n_frags = 2},
	{.mtu = 68, .msgsize = 127, .n_frags = 2},
	{.mtu = 68, .msgsize = 128, .n_frags = 2},
	{.mtu = 68, .msgsize = 129, .n_frags = 3},
	{.mtu = 68, .msgsize = 130, .n_frags = 3},
};
232  
mctp_frag_test_to_desc(const struct mctp_frag_test * t,char * desc)233  static void mctp_frag_test_to_desc(const struct mctp_frag_test *t, char *desc)
234  {
235  	sprintf(desc, "mtu %d len %d -> %d frags",
236  		t->msgsize, t->mtu, t->n_frags);
237  }
238  
239  KUNIT_ARRAY_PARAM(mctp_frag, mctp_frag_tests, mctp_frag_test_to_desc);
240  
/* parameters for the mctp_test_rx_input test */
struct mctp_rx_input_test {
	struct mctp_hdr hdr;	/* header of the injected packet */
	bool input;		/* expect the EID 8 route to capture it */
};
245  
mctp_test_rx_input(struct kunit * test)246  static void mctp_test_rx_input(struct kunit *test)
247  {
248  	const struct mctp_rx_input_test *params;
249  	struct mctp_test_route *rt;
250  	struct mctp_test_dev *dev;
251  	struct sk_buff *skb;
252  
253  	params = test->param_value;
254  
255  	dev = mctp_test_create_dev();
256  	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, dev);
257  
258  	rt = mctp_test_create_route(&init_net, dev->mdev, 8, 68);
259  	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, rt);
260  
261  	skb = mctp_test_create_skb(&params->hdr, 1);
262  	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, skb);
263  
264  	mctp_pkttype_receive(skb, dev->ndev, &mctp_packet_type, NULL);
265  
266  	KUNIT_EXPECT_EQ(test, !!rt->pkts.qlen, params->input);
267  
268  	mctp_test_route_destroy(test, rt);
269  	mctp_test_destroy_dev(dev);
270  }
271  
/* initialiser for a struct mctp_hdr from its four raw fields: version,
 * source EID, destination EID and the flags/seq/tag byte
 */
#define RX_HDR(_ver, _src, _dest, _fst) \
	{ .ver = _ver, .src = _src, .dest = _dest, .flags_seq_tag = _fst }

/* we have a route for EID 8 only */
static const struct mctp_rx_input_test mctp_rx_input_tests[] = {
	{ .hdr = RX_HDR(1, 10, 8, 0), .input = true },
	{ .hdr = RX_HDR(1, 10, 9, 0), .input = false }, /* no input route */
	{ .hdr = RX_HDR(2, 10, 8, 0), .input = false }, /* invalid version */
};
281  
mctp_rx_input_test_to_desc(const struct mctp_rx_input_test * t,char * desc)282  static void mctp_rx_input_test_to_desc(const struct mctp_rx_input_test *t,
283  				       char *desc)
284  {
285  	sprintf(desc, "{%x,%x,%x,%x}", t->hdr.ver, t->hdr.src, t->hdr.dest,
286  		t->hdr.flags_seq_tag);
287  }
288  
289  KUNIT_ARRAY_PARAM(mctp_rx_input, mctp_rx_input_tests,
290  		  mctp_rx_input_test_to_desc);
291  
292  /* set up a local dev, route on EID 8, and a socket listening on type 0 */
__mctp_route_test_init(struct kunit * test,struct mctp_test_dev ** devp,struct mctp_test_route ** rtp,struct socket ** sockp,unsigned int netid)293  static void __mctp_route_test_init(struct kunit *test,
294  				   struct mctp_test_dev **devp,
295  				   struct mctp_test_route **rtp,
296  				   struct socket **sockp,
297  				   unsigned int netid)
298  {
299  	struct sockaddr_mctp addr = {0};
300  	struct mctp_test_route *rt;
301  	struct mctp_test_dev *dev;
302  	struct socket *sock;
303  	int rc;
304  
305  	dev = mctp_test_create_dev();
306  	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, dev);
307  	if (netid != MCTP_NET_ANY)
308  		WRITE_ONCE(dev->mdev->net, netid);
309  
310  	rt = mctp_test_create_route(&init_net, dev->mdev, 8, 68);
311  	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, rt);
312  
313  	rc = sock_create_kern(&init_net, AF_MCTP, SOCK_DGRAM, 0, &sock);
314  	KUNIT_ASSERT_EQ(test, rc, 0);
315  
316  	addr.smctp_family = AF_MCTP;
317  	addr.smctp_network = netid;
318  	addr.smctp_addr.s_addr = 8;
319  	addr.smctp_type = 0;
320  	rc = kernel_bind(sock, (struct sockaddr *)&addr, sizeof(addr));
321  	KUNIT_ASSERT_EQ(test, rc, 0);
322  
323  	*rtp = rt;
324  	*devp = dev;
325  	*sockp = sock;
326  }
327  
__mctp_route_test_fini(struct kunit * test,struct mctp_test_dev * dev,struct mctp_test_route * rt,struct socket * sock)328  static void __mctp_route_test_fini(struct kunit *test,
329  				   struct mctp_test_dev *dev,
330  				   struct mctp_test_route *rt,
331  				   struct socket *sock)
332  {
333  	sock_release(sock);
334  	mctp_test_route_destroy(test, rt);
335  	mctp_test_destroy_dev(dev);
336  }
337  
/* parameters for the mctp_test_route_input_sk test */
struct mctp_route_input_sk_test {
	struct mctp_hdr hdr;	/* header of the routed packet */
	u8 type;		/* sole payload byte: the message type */
	bool deliver;		/* expect delivery to the type-0 socket */
};
343  
mctp_test_route_input_sk(struct kunit * test)344  static void mctp_test_route_input_sk(struct kunit *test)
345  {
346  	const struct mctp_route_input_sk_test *params;
347  	struct sk_buff *skb, *skb2;
348  	struct mctp_test_route *rt;
349  	struct mctp_test_dev *dev;
350  	struct socket *sock;
351  	int rc;
352  
353  	params = test->param_value;
354  
355  	__mctp_route_test_init(test, &dev, &rt, &sock, MCTP_NET_ANY);
356  
357  	skb = mctp_test_create_skb_data(&params->hdr, &params->type);
358  	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, skb);
359  
360  	mctp_test_skb_set_dev(skb, dev);
361  
362  	rc = mctp_route_input(&rt->rt, skb);
363  
364  	if (params->deliver) {
365  		KUNIT_EXPECT_EQ(test, rc, 0);
366  
367  		skb2 = skb_recv_datagram(sock->sk, MSG_DONTWAIT, &rc);
368  		KUNIT_EXPECT_NOT_ERR_OR_NULL(test, skb2);
369  		KUNIT_EXPECT_EQ(test, skb2->len, 1);
370  
371  		skb_free_datagram(sock->sk, skb2);
372  
373  	} else {
374  		KUNIT_EXPECT_NE(test, rc, 0);
375  		skb2 = skb_recv_datagram(sock->sk, MSG_DONTWAIT, &rc);
376  		KUNIT_EXPECT_NULL(test, skb2);
377  	}
378  
379  	__mctp_route_test_fini(test, dev, rt, sock);
380  }
381  
/* shorthands for building flags_seq_tag values */
#define FL_S	(MCTP_HDR_FLAG_SOM)
#define FL_E	(MCTP_HDR_FLAG_EOM)
#define FL_TO	(MCTP_HDR_FLAG_TO)
#define FL_T(t)	((t) & MCTP_HDR_TAG_MASK)

/*
 * The fixture socket is bound to type 0. Only the complete (SOM | EOM)
 * message with TO set and type 0 is expected to be delivered.
 */
static const struct mctp_route_input_sk_test mctp_route_input_sk_tests[] = {
	{ .hdr = RX_HDR(1, 10, 8, FL_S | FL_E | FL_TO), .type = 0, .deliver = true },
	{ .hdr = RX_HDR(1, 10, 8, FL_S | FL_E | FL_TO), .type = 1, .deliver = false },
	{ .hdr = RX_HDR(1, 10, 8, FL_S | FL_E), .type = 0, .deliver = false },
	{ .hdr = RX_HDR(1, 10, 8, FL_E | FL_TO), .type = 0, .deliver = false },
	{ .hdr = RX_HDR(1, 10, 8, FL_TO), .type = 0, .deliver = false },
	{ .hdr = RX_HDR(1, 10, 8, 0), .type = 0, .deliver = false },
};
395  
mctp_route_input_sk_to_desc(const struct mctp_route_input_sk_test * t,char * desc)396  static void mctp_route_input_sk_to_desc(const struct mctp_route_input_sk_test *t,
397  					char *desc)
398  {
399  	sprintf(desc, "{%x,%x,%x,%x} type %d", t->hdr.ver, t->hdr.src,
400  		t->hdr.dest, t->hdr.flags_seq_tag, t->type);
401  }
402  
403  KUNIT_ARRAY_PARAM(mctp_route_input_sk, mctp_route_input_sk_tests,
404  		  mctp_route_input_sk_to_desc);
405  
/* parameters for the mctp_test_route_input_sk_reasm test */
struct mctp_route_input_sk_reasm_test {
	const char *name;		/* case description for KUnit */
	struct mctp_hdr hdrs[4];	/* fragment headers, in send order */
	int n_hdrs;			/* number of valid entries in @hdrs */
	int rx_len;			/* expected message length; 0: none */
};
412  
mctp_test_route_input_sk_reasm(struct kunit * test)413  static void mctp_test_route_input_sk_reasm(struct kunit *test)
414  {
415  	const struct mctp_route_input_sk_reasm_test *params;
416  	struct sk_buff *skb, *skb2;
417  	struct mctp_test_route *rt;
418  	struct mctp_test_dev *dev;
419  	struct socket *sock;
420  	int i, rc;
421  	u8 c;
422  
423  	params = test->param_value;
424  
425  	__mctp_route_test_init(test, &dev, &rt, &sock, MCTP_NET_ANY);
426  
427  	for (i = 0; i < params->n_hdrs; i++) {
428  		c = i;
429  		skb = mctp_test_create_skb_data(&params->hdrs[i], &c);
430  		KUNIT_ASSERT_NOT_ERR_OR_NULL(test, skb);
431  
432  		mctp_test_skb_set_dev(skb, dev);
433  
434  		rc = mctp_route_input(&rt->rt, skb);
435  	}
436  
437  	skb2 = skb_recv_datagram(sock->sk, MSG_DONTWAIT, &rc);
438  
439  	if (params->rx_len) {
440  		KUNIT_EXPECT_NOT_ERR_OR_NULL(test, skb2);
441  		KUNIT_EXPECT_EQ(test, skb2->len, params->rx_len);
442  		skb_free_datagram(sock->sk, skb2);
443  
444  	} else {
445  		KUNIT_EXPECT_NULL(test, skb2);
446  	}
447  
448  	__mctp_route_test_fini(test, dev, rt, sock);
449  }
450  
/* a fragment from EID 10 to EID 8 with TO set, extra flags @f and sequence
 * number @s
 */
#define RX_FRAG(f, s) RX_HDR(1, 10, 8, FL_TO | (f) | ((s) << MCTP_HDR_SEQ_SHIFT))

/*
 * Each fragment carries a one-byte payload, so a successfully reassembled
 * message is n_hdrs bytes long; rx_len of 0 means nothing should arrive.
 */
static const struct mctp_route_input_sk_reasm_test mctp_route_input_sk_reasm_tests[] = {
	{
		.name = "single packet",
		.hdrs = {
			RX_FRAG(FL_S | FL_E, 0),
		},
		.n_hdrs = 1,
		.rx_len = 1,
	},
	{
		.name = "single packet, offset seq",
		.hdrs = {
			RX_FRAG(FL_S | FL_E, 1),
		},
		.n_hdrs = 1,
		.rx_len = 1,
	},
	{
		.name = "start & end packets",
		.hdrs = {
			RX_FRAG(FL_S, 0),
			RX_FRAG(FL_E, 1),
		},
		.n_hdrs = 2,
		.rx_len = 2,
	},
	{
		.name = "start & end packets, offset seq",
		.hdrs = {
			RX_FRAG(FL_S, 1),
			RX_FRAG(FL_E, 2),
		},
		.n_hdrs = 2,
		.rx_len = 2,
	},
	{
		.name = "start & end packets, out of order",
		.hdrs = {
			RX_FRAG(FL_E, 1),
			RX_FRAG(FL_S, 0),
		},
		.n_hdrs = 2,
		.rx_len = 0,
	},
	{
		.name = "start, middle & end packets",
		.hdrs = {
			RX_FRAG(FL_S, 0),
			RX_FRAG(0,    1),
			RX_FRAG(FL_E, 2),
		},
		.n_hdrs = 3,
		.rx_len = 3,
	},
	{
		.name = "missing seq",
		.hdrs = {
			RX_FRAG(FL_S, 0),
			RX_FRAG(FL_E, 2),
		},
		.n_hdrs = 2,
		.rx_len = 0,
	},
	{
		.name = "seq wrap",
		.hdrs = {
			RX_FRAG(FL_S, 3),
			RX_FRAG(FL_E, 0),
		},
		.n_hdrs = 2,
		.rx_len = 2,
	},
};
526  
mctp_route_input_sk_reasm_to_desc(const struct mctp_route_input_sk_reasm_test * t,char * desc)527  static void mctp_route_input_sk_reasm_to_desc(
528  				const struct mctp_route_input_sk_reasm_test *t,
529  				char *desc)
530  {
531  	sprintf(desc, "%s", t->name);
532  }
533  
534  KUNIT_ARRAY_PARAM(mctp_route_input_sk_reasm, mctp_route_input_sk_reasm_tests,
535  		  mctp_route_input_sk_reasm_to_desc);
536  
/* parameters for the mctp_test_route_input_sk_keys test */
struct mctp_route_input_sk_keys_test {
	const char	*name;		/* case description for KUnit */
	mctp_eid_t	key_peer_addr;	/* peer address of the reserved key */
	mctp_eid_t	key_local_addr;	/* local address of the reserved key */
	u8		key_tag;	/* tag of the reserved key */
	struct mctp_hdr hdr;		/* header of the routed packet */
	bool		deliver;	/* expect delivery to the socket */
};
545  
546  /* test packet rx in the presence of various key configurations */
mctp_test_route_input_sk_keys(struct kunit * test)547  static void mctp_test_route_input_sk_keys(struct kunit *test)
548  {
549  	const struct mctp_route_input_sk_keys_test *params;
550  	struct mctp_test_route *rt;
551  	struct sk_buff *skb, *skb2;
552  	struct mctp_test_dev *dev;
553  	struct mctp_sk_key *key;
554  	struct netns_mctp *mns;
555  	struct mctp_sock *msk;
556  	struct socket *sock;
557  	unsigned long flags;
558  	unsigned int net;
559  	int rc;
560  	u8 c;
561  
562  	params = test->param_value;
563  
564  	dev = mctp_test_create_dev();
565  	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, dev);
566  	net = READ_ONCE(dev->mdev->net);
567  
568  	rt = mctp_test_create_route(&init_net, dev->mdev, 8, 68);
569  	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, rt);
570  
571  	rc = sock_create_kern(&init_net, AF_MCTP, SOCK_DGRAM, 0, &sock);
572  	KUNIT_ASSERT_EQ(test, rc, 0);
573  
574  	msk = container_of(sock->sk, struct mctp_sock, sk);
575  	mns = &sock_net(sock->sk)->mctp;
576  
577  	/* set the incoming tag according to test params */
578  	key = mctp_key_alloc(msk, net, params->key_local_addr,
579  			     params->key_peer_addr, params->key_tag,
580  			     GFP_KERNEL);
581  
582  	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, key);
583  
584  	spin_lock_irqsave(&mns->keys_lock, flags);
585  	mctp_reserve_tag(&init_net, key, msk);
586  	spin_unlock_irqrestore(&mns->keys_lock, flags);
587  
588  	/* create packet and route */
589  	c = 0;
590  	skb = mctp_test_create_skb_data(&params->hdr, &c);
591  	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, skb);
592  
593  	mctp_test_skb_set_dev(skb, dev);
594  
595  	rc = mctp_route_input(&rt->rt, skb);
596  
597  	/* (potentially) receive message */
598  	skb2 = skb_recv_datagram(sock->sk, MSG_DONTWAIT, &rc);
599  
600  	if (params->deliver)
601  		KUNIT_EXPECT_NOT_ERR_OR_NULL(test, skb2);
602  	else
603  		KUNIT_EXPECT_PTR_EQ(test, skb2, NULL);
604  
605  	if (skb2)
606  		skb_free_datagram(sock->sk, skb2);
607  
608  	mctp_key_unref(key);
609  	__mctp_route_test_fini(test, dev, rt, sock);
610  }
611  
/*
 * Each case reserves one key (local/peer address and tag from the key_*
 * fields) on an unbound socket, then routes one packet with @hdr. Delivery
 * therefore depends only on whether that key matches the packet.
 */
static const struct mctp_route_input_sk_keys_test mctp_route_input_sk_keys_tests[] = {
	{
		.name = "direct match",
		.key_peer_addr = 9,
		.key_local_addr = 8,
		.key_tag = 1,
		.hdr = RX_HDR(1, 9, 8, FL_S | FL_E | FL_T(1)),
		.deliver = true,
	},
	{
		.name = "flipped src/dest",
		.key_peer_addr = 8,
		.key_local_addr = 9,
		.key_tag = 1,
		.hdr = RX_HDR(1, 9, 8, FL_S | FL_E | FL_T(1)),
		.deliver = false,
	},
	{
		.name = "peer addr mismatch",
		.key_peer_addr = 9,
		.key_local_addr = 8,
		.key_tag = 1,
		.hdr = RX_HDR(1, 10, 8, FL_S | FL_E | FL_T(1)),
		.deliver = false,
	},
	{
		.name = "tag value mismatch",
		.key_peer_addr = 9,
		.key_local_addr = 8,
		.key_tag = 1,
		.hdr = RX_HDR(1, 9, 8, FL_S | FL_E | FL_T(2)),
		.deliver = false,
	},
	{
		.name = "TO mismatch",
		.key_peer_addr = 9,
		.key_local_addr = 8,
		.key_tag = 1,
		.hdr = RX_HDR(1, 9, 8, FL_S | FL_E | FL_T(1) | FL_TO),
		.deliver = false,
	},
	{
		.name = "broadcast response",
		.key_peer_addr = MCTP_ADDR_ANY,
		.key_local_addr = 8,
		.key_tag = 1,
		.hdr = RX_HDR(1, 11, 8, FL_S | FL_E | FL_T(1)),
		.deliver = true,
	},
	{
		.name = "any local match",
		.key_peer_addr = 12,
		.key_local_addr = MCTP_ADDR_ANY,
		.key_tag = 1,
		.hdr = RX_HDR(1, 12, 8, FL_S | FL_E | FL_T(1)),
		.deliver = true,
	},
};
670  
mctp_route_input_sk_keys_to_desc(const struct mctp_route_input_sk_keys_test * t,char * desc)671  static void mctp_route_input_sk_keys_to_desc(
672  				const struct mctp_route_input_sk_keys_test *t,
673  				char *desc)
674  {
675  	sprintf(desc, "%s", t->name);
676  }
677  
678  KUNIT_ARRAY_PARAM(mctp_route_input_sk_keys, mctp_route_input_sk_keys_tests,
679  		  mctp_route_input_sk_keys_to_desc);
680  
/* per-network state for the multiple-nets tests */
struct test_net {
	unsigned int netid;		/* network ID for the device and socket */
	struct mctp_test_dev *dev;	/* fixture device */
	struct mctp_test_route *rt;	/* fixture route (EID 8) */
	struct socket *sock;		/* fixture socket */
	struct sk_buff *skb;		/* packet to route on this network */
	struct mctp_sk_key *key;	/* reserved key (key variant only) */
	struct {
		u8 type;		/* message type byte */
		unsigned int data;	/* set to @netid by the *_init helpers */
	} msg;				/* payload copied into @skb */
};
693  
694  static void
mctp_test_route_input_multiple_nets_bind_init(struct kunit * test,struct test_net * t)695  mctp_test_route_input_multiple_nets_bind_init(struct kunit *test,
696  					      struct test_net *t)
697  {
698  	struct mctp_hdr hdr = RX_HDR(1, 9, 8, FL_S | FL_E | FL_T(1) | FL_TO);
699  
700  	t->msg.data = t->netid;
701  
702  	__mctp_route_test_init(test, &t->dev, &t->rt, &t->sock, t->netid);
703  
704  	t->skb = mctp_test_create_skb_data(&hdr, &t->msg);
705  	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, t->skb);
706  	mctp_test_skb_set_dev(t->skb, t->dev);
707  }
708  
/* tear down the fixture created by ..._multiple_nets_bind_init() for @t */
static void
mctp_test_route_input_multiple_nets_bind_fini(struct kunit *test,
					      struct test_net *t)
{
	__mctp_route_test_fini(test, t->dev, t->rt, t->sock);
}
715  
716  /* Test that skbs from different nets (otherwise identical) get routed to their
717   * corresponding socket via the sockets' bind()
718   */
mctp_test_route_input_multiple_nets_bind(struct kunit * test)719  static void mctp_test_route_input_multiple_nets_bind(struct kunit *test)
720  {
721  	struct sk_buff *rx_skb1, *rx_skb2;
722  	struct test_net t1, t2;
723  	int rc;
724  
725  	t1.netid = 1;
726  	t2.netid = 2;
727  
728  	t1.msg.type = 0;
729  	t2.msg.type = 0;
730  
731  	mctp_test_route_input_multiple_nets_bind_init(test, &t1);
732  	mctp_test_route_input_multiple_nets_bind_init(test, &t2);
733  
734  	rc = mctp_route_input(&t1.rt->rt, t1.skb);
735  	KUNIT_ASSERT_EQ(test, rc, 0);
736  	rc = mctp_route_input(&t2.rt->rt, t2.skb);
737  	KUNIT_ASSERT_EQ(test, rc, 0);
738  
739  	rx_skb1 = skb_recv_datagram(t1.sock->sk, MSG_DONTWAIT, &rc);
740  	KUNIT_EXPECT_NOT_ERR_OR_NULL(test, rx_skb1);
741  	KUNIT_EXPECT_EQ(test, rx_skb1->len, sizeof(t1.msg));
742  	KUNIT_EXPECT_EQ(test,
743  			*(unsigned int *)skb_pull(rx_skb1, sizeof(t1.msg.data)),
744  			t1.netid);
745  	kfree_skb(rx_skb1);
746  
747  	rx_skb2 = skb_recv_datagram(t2.sock->sk, MSG_DONTWAIT, &rc);
748  	KUNIT_EXPECT_NOT_ERR_OR_NULL(test, rx_skb2);
749  	KUNIT_EXPECT_EQ(test, rx_skb2->len, sizeof(t2.msg));
750  	KUNIT_EXPECT_EQ(test,
751  			*(unsigned int *)skb_pull(rx_skb2, sizeof(t2.msg.data)),
752  			t2.netid);
753  	kfree_skb(rx_skb2);
754  
755  	mctp_test_route_input_multiple_nets_bind_fini(test, &t1);
756  	mctp_test_route_input_multiple_nets_bind_fini(test, &t2);
757  }
758  
759  static void
mctp_test_route_input_multiple_nets_key_init(struct kunit * test,struct test_net * t)760  mctp_test_route_input_multiple_nets_key_init(struct kunit *test,
761  					     struct test_net *t)
762  {
763  	struct mctp_hdr hdr = RX_HDR(1, 9, 8, FL_S | FL_E | FL_T(1));
764  	struct mctp_sock *msk;
765  	struct netns_mctp *mns;
766  	unsigned long flags;
767  
768  	t->msg.data = t->netid;
769  
770  	__mctp_route_test_init(test, &t->dev, &t->rt, &t->sock, t->netid);
771  
772  	msk = container_of(t->sock->sk, struct mctp_sock, sk);
773  
774  	t->key = mctp_key_alloc(msk, t->netid, hdr.dest, hdr.src, 1, GFP_KERNEL);
775  	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, t->key);
776  
777  	mns = &sock_net(t->sock->sk)->mctp;
778  	spin_lock_irqsave(&mns->keys_lock, flags);
779  	mctp_reserve_tag(&init_net, t->key, msk);
780  	spin_unlock_irqrestore(&mns->keys_lock, flags);
781  
782  	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, t->key);
783  	t->skb = mctp_test_create_skb_data(&hdr, &t->msg);
784  	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, t->skb);
785  	mctp_test_skb_set_dev(t->skb, t->dev);
786  }
787  
/* drop the test's reference on @t's key, then tear down its fixture */
static void
mctp_test_route_input_multiple_nets_key_fini(struct kunit *test,
					     struct test_net *t)
{
	mctp_key_unref(t->key);
	__mctp_route_test_fini(test, t->dev, t->rt, t->sock);
}
795  
796  /* test that skbs from different nets (otherwise identical) get routed to their
797   * corresponding socket via the sk_key
798   */
mctp_test_route_input_multiple_nets_key(struct kunit * test)799  static void mctp_test_route_input_multiple_nets_key(struct kunit *test)
800  {
801  	struct sk_buff *rx_skb1, *rx_skb2;
802  	struct test_net t1, t2;
803  	int rc;
804  
805  	t1.netid = 1;
806  	t2.netid = 2;
807  
808  	/* use type 1 which is not bound */
809  	t1.msg.type = 1;
810  	t2.msg.type = 1;
811  
812  	mctp_test_route_input_multiple_nets_key_init(test, &t1);
813  	mctp_test_route_input_multiple_nets_key_init(test, &t2);
814  
815  	rc = mctp_route_input(&t1.rt->rt, t1.skb);
816  	KUNIT_ASSERT_EQ(test, rc, 0);
817  	rc = mctp_route_input(&t2.rt->rt, t2.skb);
818  	KUNIT_ASSERT_EQ(test, rc, 0);
819  
820  	rx_skb1 = skb_recv_datagram(t1.sock->sk, MSG_DONTWAIT, &rc);
821  	KUNIT_EXPECT_NOT_ERR_OR_NULL(test, rx_skb1);
822  	KUNIT_EXPECT_EQ(test, rx_skb1->len, sizeof(t1.msg));
823  	KUNIT_EXPECT_EQ(test,
824  			*(unsigned int *)skb_pull(rx_skb1, sizeof(t1.msg.data)),
825  			t1.netid);
826  	kfree_skb(rx_skb1);
827  
828  	rx_skb2 = skb_recv_datagram(t2.sock->sk, MSG_DONTWAIT, &rc);
829  	KUNIT_EXPECT_NOT_ERR_OR_NULL(test, rx_skb2);
830  	KUNIT_EXPECT_EQ(test, rx_skb2->len, sizeof(t2.msg));
831  	KUNIT_EXPECT_EQ(test,
832  			*(unsigned int *)skb_pull(rx_skb2, sizeof(t2.msg.data)),
833  			t2.netid);
834  	kfree_skb(rx_skb2);
835  
836  	mctp_test_route_input_multiple_nets_key_fini(test, &t1);
837  	mctp_test_route_input_multiple_nets_key_fini(test, &t2);
838  }
839  
840  #if IS_ENABLED(CONFIG_MCTP_FLOWS)
841  
mctp_test_flow_init(struct kunit * test,struct mctp_test_dev ** devp,struct mctp_test_route ** rtp,struct socket ** sock,struct sk_buff ** skbp,unsigned int len)842  static void mctp_test_flow_init(struct kunit *test,
843  				struct mctp_test_dev **devp,
844  				struct mctp_test_route **rtp,
845  				struct socket **sock,
846  				struct sk_buff **skbp,
847  				unsigned int len)
848  {
849  	struct mctp_test_route *rt;
850  	struct mctp_test_dev *dev;
851  	struct sk_buff *skb;
852  
853  	/* we have a slightly odd routing setup here; the test route
854  	 * is for EID 8, which is our local EID. We don't do a routing
855  	 * lookup, so that's fine - all we require is a path through
856  	 * mctp_local_output, which will call rt->output on whatever
857  	 * route we provide
858  	 */
859  	__mctp_route_test_init(test, &dev, &rt, sock, MCTP_NET_ANY);
860  
861  	/* Assign a single EID. ->addrs is freed on mctp netdev release */
862  	dev->mdev->addrs = kmalloc(sizeof(u8), GFP_KERNEL);
863  	dev->mdev->num_addrs = 1;
864  	dev->mdev->addrs[0] = 8;
865  
866  	skb = alloc_skb(len + sizeof(struct mctp_hdr) + 1, GFP_KERNEL);
867  	KUNIT_ASSERT_TRUE(test, skb);
868  	__mctp_cb(skb);
869  	skb_reserve(skb, sizeof(struct mctp_hdr) + 1);
870  	memset(skb_put(skb, len), 0, len);
871  
872  	/* take a ref for the route, we'll decrement in local output */
873  	refcount_inc(&rt->rt.refs);
874  
875  	*devp = dev;
876  	*rtp = rt;
877  	*skbp = skb;
878  }
879  
mctp_test_flow_fini(struct kunit * test,struct mctp_test_dev * dev,struct mctp_test_route * rt,struct socket * sock)880  static void mctp_test_flow_fini(struct kunit *test,
881  				struct mctp_test_dev *dev,
882  				struct mctp_test_route *rt,
883  				struct socket *sock)
884  {
885  	__mctp_route_test_fini(test, dev, rt, sock);
886  }
887  
888  /* test that an outgoing skb has the correct MCTP extension data set */
mctp_test_packet_flow(struct kunit * test)889  static void mctp_test_packet_flow(struct kunit *test)
890  {
891  	struct sk_buff *skb, *skb2;
892  	struct mctp_test_route *rt;
893  	struct mctp_test_dev *dev;
894  	struct mctp_flow *flow;
895  	struct socket *sock;
896  	u8 dst = 8;
897  	int n, rc;
898  
899  	mctp_test_flow_init(test, &dev, &rt, &sock, &skb, 30);
900  
901  	rc = mctp_local_output(sock->sk, &rt->rt, skb, dst, MCTP_TAG_OWNER);
902  	KUNIT_ASSERT_EQ(test, rc, 0);
903  
904  	n = rt->pkts.qlen;
905  	KUNIT_ASSERT_EQ(test, n, 1);
906  
907  	skb2 = skb_dequeue(&rt->pkts);
908  	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, skb2);
909  
910  	flow = skb_ext_find(skb2, SKB_EXT_MCTP);
911  	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, flow);
912  	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, flow->key);
913  	KUNIT_ASSERT_PTR_EQ(test, flow->key->sk, sock->sk);
914  
915  	kfree_skb(skb2);
916  	mctp_test_flow_fini(test, dev, rt, sock);
917  }
918  
919  /* test that outgoing skbs, after fragmentation, all have the correct MCTP
920   * extension data set.
921   */
mctp_test_fragment_flow(struct kunit * test)922  static void mctp_test_fragment_flow(struct kunit *test)
923  {
924  	struct mctp_flow *flows[2];
925  	struct sk_buff *tx_skbs[2];
926  	struct mctp_test_route *rt;
927  	struct mctp_test_dev *dev;
928  	struct sk_buff *skb;
929  	struct socket *sock;
930  	u8 dst = 8;
931  	int n, rc;
932  
933  	mctp_test_flow_init(test, &dev, &rt, &sock, &skb, 100);
934  
935  	rc = mctp_local_output(sock->sk, &rt->rt, skb, dst, MCTP_TAG_OWNER);
936  	KUNIT_ASSERT_EQ(test, rc, 0);
937  
938  	n = rt->pkts.qlen;
939  	KUNIT_ASSERT_EQ(test, n, 2);
940  
941  	/* both resulting packets should have the same flow data */
942  	tx_skbs[0] = skb_dequeue(&rt->pkts);
943  	tx_skbs[1] = skb_dequeue(&rt->pkts);
944  
945  	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, tx_skbs[0]);
946  	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, tx_skbs[1]);
947  
948  	flows[0] = skb_ext_find(tx_skbs[0], SKB_EXT_MCTP);
949  	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, flows[0]);
950  	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, flows[0]->key);
951  	KUNIT_ASSERT_PTR_EQ(test, flows[0]->key->sk, sock->sk);
952  
953  	flows[1] = skb_ext_find(tx_skbs[1], SKB_EXT_MCTP);
954  	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, flows[1]);
955  	KUNIT_ASSERT_PTR_EQ(test, flows[1]->key, flows[0]->key);
956  
957  	kfree_skb(tx_skbs[0]);
958  	kfree_skb(tx_skbs[1]);
959  	mctp_test_flow_fini(test, dev, rt, sock);
960  }
961  
962  #else
mctp_test_packet_flow(struct kunit * test)963  static void mctp_test_packet_flow(struct kunit *test)
964  {
965  	kunit_skip(test, "Requires CONFIG_MCTP_FLOWS=y");
966  }
967  
mctp_test_fragment_flow(struct kunit * test)968  static void mctp_test_fragment_flow(struct kunit *test)
969  {
970  	kunit_skip(test, "Requires CONFIG_MCTP_FLOWS=y");
971  }
972  #endif
973  
974  /* Test that outgoing skbs cause a suitable tag to be created */
mctp_test_route_output_key_create(struct kunit * test)975  static void mctp_test_route_output_key_create(struct kunit *test)
976  {
977  	const unsigned int netid = 50;
978  	const u8 dst = 26, src = 15;
979  	struct mctp_test_route *rt;
980  	struct mctp_test_dev *dev;
981  	struct mctp_sk_key *key;
982  	struct netns_mctp *mns;
983  	unsigned long flags;
984  	struct socket *sock;
985  	struct sk_buff *skb;
986  	bool empty, single;
987  	const int len = 2;
988  	int rc;
989  
990  	dev = mctp_test_create_dev();
991  	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, dev);
992  	WRITE_ONCE(dev->mdev->net, netid);
993  
994  	rt = mctp_test_create_route(&init_net, dev->mdev, dst, 68);
995  	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, rt);
996  
997  	rc = sock_create_kern(&init_net, AF_MCTP, SOCK_DGRAM, 0, &sock);
998  	KUNIT_ASSERT_EQ(test, rc, 0);
999  
1000  	dev->mdev->addrs = kmalloc(sizeof(u8), GFP_KERNEL);
1001  	dev->mdev->num_addrs = 1;
1002  	dev->mdev->addrs[0] = src;
1003  
1004  	skb = alloc_skb(sizeof(struct mctp_hdr) + 1 + len, GFP_KERNEL);
1005  	KUNIT_ASSERT_TRUE(test, skb);
1006  	__mctp_cb(skb);
1007  	skb_reserve(skb, sizeof(struct mctp_hdr) + 1 + len);
1008  	memset(skb_put(skb, len), 0, len);
1009  
1010  	refcount_inc(&rt->rt.refs);
1011  
1012  	mns = &sock_net(sock->sk)->mctp;
1013  
1014  	/* We assume we're starting from an empty keys list, which requires
1015  	 * preceding tests to clean up correctly!
1016  	 */
1017  	spin_lock_irqsave(&mns->keys_lock, flags);
1018  	empty = hlist_empty(&mns->keys);
1019  	spin_unlock_irqrestore(&mns->keys_lock, flags);
1020  	KUNIT_ASSERT_TRUE(test, empty);
1021  
1022  	rc = mctp_local_output(sock->sk, &rt->rt, skb, dst, MCTP_TAG_OWNER);
1023  	KUNIT_ASSERT_EQ(test, rc, 0);
1024  
1025  	key = NULL;
1026  	single = false;
1027  	spin_lock_irqsave(&mns->keys_lock, flags);
1028  	if (!hlist_empty(&mns->keys)) {
1029  		key = hlist_entry(mns->keys.first, struct mctp_sk_key, hlist);
1030  		single = hlist_is_singular_node(&key->hlist, &mns->keys);
1031  	}
1032  	spin_unlock_irqrestore(&mns->keys_lock, flags);
1033  
1034  	KUNIT_ASSERT_NOT_NULL(test, key);
1035  	KUNIT_ASSERT_TRUE(test, single);
1036  
1037  	KUNIT_EXPECT_EQ(test, key->net, netid);
1038  	KUNIT_EXPECT_EQ(test, key->local_addr, src);
1039  	KUNIT_EXPECT_EQ(test, key->peer_addr, dst);
1040  	/* key has incoming tag, so inverse of what we sent */
1041  	KUNIT_EXPECT_FALSE(test, key->tag & MCTP_TAG_OWNER);
1042  
1043  	sock_release(sock);
1044  	mctp_test_route_destroy(test, rt);
1045  	mctp_test_destroy_dev(dev);
1046  }
1047  
/*
 * Test cases for the "mctp" suite. Parameterised cases take their inputs
 * from the *_gen_params generators defined by the KUNIT_ARRAY_PARAM() uses
 * above. mctp_test_route_output_key_create asserts that the netns key list
 * starts out empty, so it relies on the earlier cases cleaning up their keys.
 */
static struct kunit_case mctp_test_cases[] = {
	KUNIT_CASE_PARAM(mctp_test_fragment, mctp_frag_gen_params),
	KUNIT_CASE_PARAM(mctp_test_rx_input, mctp_rx_input_gen_params),
	KUNIT_CASE_PARAM(mctp_test_route_input_sk, mctp_route_input_sk_gen_params),
	KUNIT_CASE_PARAM(mctp_test_route_input_sk_reasm,
			 mctp_route_input_sk_reasm_gen_params),
	KUNIT_CASE_PARAM(mctp_test_route_input_sk_keys,
			 mctp_route_input_sk_keys_gen_params),
	KUNIT_CASE(mctp_test_route_input_multiple_nets_bind),
	KUNIT_CASE(mctp_test_route_input_multiple_nets_key),
	KUNIT_CASE(mctp_test_packet_flow),
	KUNIT_CASE(mctp_test_fragment_flow),
	KUNIT_CASE(mctp_test_route_output_key_create),
	{}
};
1063  
/* the "mctp" KUnit suite, registered with the KUnit framework below */
static struct kunit_suite mctp_test_suite = {
	.name = "mctp",
	.test_cases = mctp_test_cases,
};

kunit_test_suite(mctp_test_suite);
1070