1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /* Server address list management
3  *
4  * Copyright (C) 2017 Red Hat, Inc. All Rights Reserved.
5  * Written by David Howells (dhowells@redhat.com)
6  */
7 
8 #include <linux/slab.h>
9 #include <linux/ctype.h>
10 #include <linux/dns_resolver.h>
11 #include <linux/inet.h>
12 #include <keys/rxrpc-type.h>
13 #include "internal.h"
14 #include "afs_fs.h"
15 
afs_free_addrlist(struct rcu_head * rcu)16 static void afs_free_addrlist(struct rcu_head *rcu)
17 {
18 	struct afs_addr_list *alist = container_of(rcu, struct afs_addr_list, rcu);
19 	unsigned int i;
20 
21 	for (i = 0; i < alist->nr_addrs; i++)
22 		rxrpc_kernel_put_peer(alist->addrs[i].peer);
23 	trace_afs_alist(alist->debug_id, refcount_read(&alist->usage), afs_alist_trace_free);
24 	kfree(alist);
25 }
26 
27 /*
28  * Release an address list.
29  */
afs_put_addrlist(struct afs_addr_list * alist,enum afs_alist_trace reason)30 void afs_put_addrlist(struct afs_addr_list *alist, enum afs_alist_trace reason)
31 {
32 	unsigned int debug_id;
33 	bool dead;
34 	int r;
35 
36 	if (!alist)
37 		return;
38 	debug_id = alist->debug_id;
39 	dead = __refcount_dec_and_test(&alist->usage, &r);
40 	trace_afs_alist(debug_id, r - 1, reason);
41 	if (dead)
42 		call_rcu(&alist->rcu, afs_free_addrlist);
43 }
44 
afs_get_addrlist(struct afs_addr_list * alist,enum afs_alist_trace reason)45 struct afs_addr_list *afs_get_addrlist(struct afs_addr_list *alist, enum afs_alist_trace reason)
46 {
47 	int r;
48 
49 	if (alist) {
50 		__refcount_inc(&alist->usage, &r);
51 		trace_afs_alist(alist->debug_id, r + 1, reason);
52 	}
53 	return alist;
54 }
55 
56 /*
57  * Allocate an address list.
58  */
afs_alloc_addrlist(unsigned int nr)59 struct afs_addr_list *afs_alloc_addrlist(unsigned int nr)
60 {
61 	struct afs_addr_list *alist;
62 	static atomic_t debug_id;
63 
64 	_enter("%u", nr);
65 
66 	if (nr > AFS_MAX_ADDRESSES)
67 		nr = AFS_MAX_ADDRESSES;
68 
69 	alist = kzalloc(struct_size(alist, addrs, nr), GFP_KERNEL);
70 	if (!alist)
71 		return NULL;
72 
73 	refcount_set(&alist->usage, 1);
74 	alist->max_addrs = nr;
75 	alist->debug_id = atomic_inc_return(&debug_id);
76 	trace_afs_alist(alist->debug_id, 1, afs_alist_trace_alloc);
77 	return alist;
78 }
79 
80 /*
81  * Parse a text string consisting of delimited addresses.
82  */
afs_parse_text_addrs(struct afs_net * net,const char * text,size_t len,char delim,unsigned short service,unsigned short port)83 struct afs_vlserver_list *afs_parse_text_addrs(struct afs_net *net,
84 					       const char *text, size_t len,
85 					       char delim,
86 					       unsigned short service,
87 					       unsigned short port)
88 {
89 	struct afs_vlserver_list *vllist;
90 	struct afs_addr_list *alist;
91 	const char *p, *end = text + len;
92 	const char *problem;
93 	unsigned int nr = 0;
94 	int ret = -ENOMEM;
95 
96 	_enter("%*.*s,%c", (int)len, (int)len, text, delim);
97 
98 	if (!len) {
99 		_leave(" = -EDESTADDRREQ [empty]");
100 		return ERR_PTR(-EDESTADDRREQ);
101 	}
102 
103 	if (delim == ':' && (memchr(text, ',', len) || !memchr(text, '.', len)))
104 		delim = ',';
105 
106 	/* Count the addresses */
107 	p = text;
108 	do {
109 		if (!*p) {
110 			problem = "nul";
111 			goto inval;
112 		}
113 		if (*p == delim)
114 			continue;
115 		nr++;
116 		if (*p == '[') {
117 			p++;
118 			if (p == end) {
119 				problem = "brace1";
120 				goto inval;
121 			}
122 			p = memchr(p, ']', end - p);
123 			if (!p) {
124 				problem = "brace2";
125 				goto inval;
126 			}
127 			p++;
128 			if (p >= end)
129 				break;
130 		}
131 
132 		p = memchr(p, delim, end - p);
133 		if (!p)
134 			break;
135 		p++;
136 	} while (p < end);
137 
138 	_debug("%u/%u addresses", nr, AFS_MAX_ADDRESSES);
139 
140 	vllist = afs_alloc_vlserver_list(1);
141 	if (!vllist)
142 		return ERR_PTR(-ENOMEM);
143 
144 	vllist->nr_servers = 1;
145 	vllist->servers[0].server = afs_alloc_vlserver("<dummy>", 7, AFS_VL_PORT);
146 	if (!vllist->servers[0].server)
147 		goto error_vl;
148 
149 	alist = afs_alloc_addrlist(nr);
150 	if (!alist)
151 		goto error;
152 
153 	/* Extract the addresses */
154 	p = text;
155 	do {
156 		const char *q, *stop;
157 		unsigned int xport = port;
158 		__be32 x[4];
159 		int family;
160 
161 		if (*p == delim) {
162 			p++;
163 			continue;
164 		}
165 
166 		if (*p == '[') {
167 			p++;
168 			q = memchr(p, ']', end - p);
169 		} else {
170 			for (q = p; q < end; q++)
171 				if (*q == '+' || *q == delim)
172 					break;
173 		}
174 
175 		if (in4_pton(p, q - p, (u8 *)&x[0], -1, &stop)) {
176 			family = AF_INET;
177 		} else if (in6_pton(p, q - p, (u8 *)x, -1, &stop)) {
178 			family = AF_INET6;
179 		} else {
180 			problem = "family";
181 			goto bad_address;
182 		}
183 
184 		p = q;
185 		if (stop != p) {
186 			problem = "nostop";
187 			goto bad_address;
188 		}
189 
190 		if (q < end && *q == ']')
191 			p++;
192 
193 		if (p < end) {
194 			if (*p == '+') {
195 				/* Port number specification "+1234" */
196 				xport = 0;
197 				p++;
198 				if (p >= end || !isdigit(*p)) {
199 					problem = "port";
200 					goto bad_address;
201 				}
202 				do {
203 					xport *= 10;
204 					xport += *p - '0';
205 					if (xport > 65535) {
206 						problem = "pval";
207 						goto bad_address;
208 					}
209 					p++;
210 				} while (p < end && isdigit(*p));
211 			} else if (*p == delim) {
212 				p++;
213 			} else {
214 				problem = "weird";
215 				goto bad_address;
216 			}
217 		}
218 
219 		if (family == AF_INET)
220 			ret = afs_merge_fs_addr4(net, alist, x[0], xport);
221 		else
222 			ret = afs_merge_fs_addr6(net, alist, x, xport);
223 		if (ret < 0)
224 			goto error;
225 
226 	} while (p < end);
227 
228 	rcu_assign_pointer(vllist->servers[0].server->addresses, alist);
229 	_leave(" = [nr %u]", alist->nr_addrs);
230 	return vllist;
231 
232 inval:
233 	_leave(" = -EINVAL [%s %zu %*.*s]",
234 	       problem, p - text, (int)len, (int)len, text);
235 	return ERR_PTR(-EINVAL);
236 bad_address:
237 	_leave(" = -EINVAL [%s %zu %*.*s]",
238 	       problem, p - text, (int)len, (int)len, text);
239 	ret = -EINVAL;
240 error:
241 	afs_put_addrlist(alist, afs_alist_trace_put_parse_error);
242 error_vl:
243 	afs_put_vlserverlist(net, vllist);
244 	return ERR_PTR(ret);
245 }
246 
247 /*
248  * Perform a DNS query for VL servers and build a up an address list.
249  */
afs_dns_query(struct afs_cell * cell,time64_t * _expiry)250 struct afs_vlserver_list *afs_dns_query(struct afs_cell *cell, time64_t *_expiry)
251 {
252 	struct afs_vlserver_list *vllist;
253 	char *result = NULL;
254 	int ret;
255 
256 	_enter("%s", cell->name);
257 
258 	ret = dns_query(cell->net->net, "afsdb", cell->name, cell->name_len,
259 			"srv=1", &result, _expiry, true);
260 	if (ret < 0) {
261 		_leave(" = %d [dns]", ret);
262 		return ERR_PTR(ret);
263 	}
264 
265 	if (*_expiry == 0)
266 		*_expiry = ktime_get_real_seconds() + 60;
267 
268 	if (ret > 1 && result[0] == 0)
269 		vllist = afs_extract_vlserver_list(cell, result, ret);
270 	else
271 		vllist = afs_parse_text_addrs(cell->net, result, ret, ',',
272 					      VL_SERVICE, AFS_VL_PORT);
273 	kfree(result);
274 	if (IS_ERR(vllist) && vllist != ERR_PTR(-ENOMEM))
275 		pr_err("Failed to parse DNS data %ld\n", PTR_ERR(vllist));
276 
277 	return vllist;
278 }
279 
280 /*
281  * Merge an IPv4 entry into a fileserver address list.
282  */
afs_merge_fs_addr4(struct afs_net * net,struct afs_addr_list * alist,__be32 xdr,u16 port)283 int afs_merge_fs_addr4(struct afs_net *net, struct afs_addr_list *alist,
284 		       __be32 xdr, u16 port)
285 {
286 	struct sockaddr_rxrpc srx;
287 	struct rxrpc_peer *peer;
288 	int i;
289 
290 	if (alist->nr_addrs >= alist->max_addrs)
291 		return 0;
292 
293 	srx.srx_family = AF_RXRPC;
294 	srx.transport_type = SOCK_DGRAM;
295 	srx.transport_len = sizeof(srx.transport.sin);
296 	srx.transport.sin.sin_family = AF_INET;
297 	srx.transport.sin.sin_port = htons(port);
298 	srx.transport.sin.sin_addr.s_addr = xdr;
299 
300 	peer = rxrpc_kernel_lookup_peer(net->socket, &srx, GFP_KERNEL);
301 	if (!peer)
302 		return -ENOMEM;
303 
304 	for (i = 0; i < alist->nr_ipv4; i++) {
305 		if (peer == alist->addrs[i].peer) {
306 			rxrpc_kernel_put_peer(peer);
307 			return 0;
308 		}
309 		if (peer <= alist->addrs[i].peer)
310 			break;
311 	}
312 
313 	if (i < alist->nr_addrs)
314 		memmove(alist->addrs + i + 1,
315 			alist->addrs + i,
316 			sizeof(alist->addrs[0]) * (alist->nr_addrs - i));
317 
318 	alist->addrs[i].peer = peer;
319 	alist->nr_ipv4++;
320 	alist->nr_addrs++;
321 	return 0;
322 }
323 
324 /*
325  * Merge an IPv6 entry into a fileserver address list.
326  */
afs_merge_fs_addr6(struct afs_net * net,struct afs_addr_list * alist,__be32 * xdr,u16 port)327 int afs_merge_fs_addr6(struct afs_net *net, struct afs_addr_list *alist,
328 		       __be32 *xdr, u16 port)
329 {
330 	struct sockaddr_rxrpc srx;
331 	struct rxrpc_peer *peer;
332 	int i;
333 
334 	if (alist->nr_addrs >= alist->max_addrs)
335 		return 0;
336 
337 	srx.srx_family = AF_RXRPC;
338 	srx.transport_type = SOCK_DGRAM;
339 	srx.transport_len = sizeof(srx.transport.sin6);
340 	srx.transport.sin6.sin6_family = AF_INET6;
341 	srx.transport.sin6.sin6_port = htons(port);
342 	memcpy(&srx.transport.sin6.sin6_addr, xdr, 16);
343 
344 	peer = rxrpc_kernel_lookup_peer(net->socket, &srx, GFP_KERNEL);
345 	if (!peer)
346 		return -ENOMEM;
347 
348 	for (i = alist->nr_ipv4; i < alist->nr_addrs; i++) {
349 		if (peer == alist->addrs[i].peer) {
350 			rxrpc_kernel_put_peer(peer);
351 			return 0;
352 		}
353 		if (peer <= alist->addrs[i].peer)
354 			break;
355 	}
356 
357 	if (i < alist->nr_addrs)
358 		memmove(alist->addrs + i + 1,
359 			alist->addrs + i,
360 			sizeof(alist->addrs[0]) * (alist->nr_addrs - i));
361 	alist->addrs[i].peer = peer;
362 	alist->nr_addrs++;
363 	return 0;
364 }
365