1 /* SPDX-License-Identifier: GPL-2.0 */
2 #include <linux/module.h>
3 #include <linux/netfilter/nf_tables.h>
4 #include <net/netfilter/nf_tables.h>
5 #include <net/netfilter/nf_tables_core.h>
6 #include <net/netfilter/nf_socket.h>
7 #include <net/inet_sock.h>
8 #include <net/tcp.h>
9 
10 struct nft_socket {
11 	enum nft_socket_keys		key:8;
12 	u8				level;		/* cgroupv2 level to extract */
13 	u8				level_user;	/* cgroupv2 level provided by userspace */
14 	u8				len;
15 	union {
16 		u8			dreg;
17 	};
18 };
19 
nft_socket_wildcard(const struct nft_pktinfo * pkt,struct nft_regs * regs,struct sock * sk,u32 * dest)20 static void nft_socket_wildcard(const struct nft_pktinfo *pkt,
21 				struct nft_regs *regs, struct sock *sk,
22 				u32 *dest)
23 {
24 	switch (nft_pf(pkt)) {
25 	case NFPROTO_IPV4:
26 		nft_reg_store8(dest, inet_sk(sk)->inet_rcv_saddr == 0);
27 		break;
28 #if IS_ENABLED(CONFIG_NF_TABLES_IPV6)
29 	case NFPROTO_IPV6:
30 		nft_reg_store8(dest, ipv6_addr_any(&sk->sk_v6_rcv_saddr));
31 		break;
32 #endif
33 	default:
34 		regs->verdict.code = NFT_BREAK;
35 		return;
36 	}
37 }
38 
39 #ifdef CONFIG_SOCK_CGROUP_DATA
40 static noinline bool
nft_sock_get_eval_cgroupv2(u32 * dest,struct sock * sk,const struct nft_pktinfo * pkt,u32 level)41 nft_sock_get_eval_cgroupv2(u32 *dest, struct sock *sk, const struct nft_pktinfo *pkt, u32 level)
42 {
43 	struct cgroup *cgrp;
44 	u64 cgid;
45 
46 	if (!sk_fullsock(sk))
47 		return false;
48 
49 	cgrp = cgroup_ancestor(sock_cgroup_ptr(&sk->sk_cgrp_data), level);
50 	if (!cgrp)
51 		return false;
52 
53 	cgid = cgroup_id(cgrp);
54 	memcpy(dest, &cgid, sizeof(u64));
55 	return true;
56 }
57 
58 /* process context only, uses current->nsproxy. */
nft_socket_cgroup_subtree_level(void)59 static noinline int nft_socket_cgroup_subtree_level(void)
60 {
61 	struct cgroup *cgrp = cgroup_get_from_path("/");
62 	int level;
63 
64 	if (IS_ERR(cgrp))
65 		return PTR_ERR(cgrp);
66 
67 	level = cgrp->level;
68 
69 	cgroup_put(cgrp);
70 
71 	if (WARN_ON_ONCE(level > 255))
72 		return -ERANGE;
73 
74 	if (WARN_ON_ONCE(level < 0))
75 		return -EINVAL;
76 
77 	return level;
78 }
79 #endif
80 
nft_socket_do_lookup(const struct nft_pktinfo * pkt)81 static struct sock *nft_socket_do_lookup(const struct nft_pktinfo *pkt)
82 {
83 	const struct net_device *indev = nft_in(pkt);
84 	const struct sk_buff *skb = pkt->skb;
85 	struct sock *sk = NULL;
86 
87 	if (!indev)
88 		return NULL;
89 
90 	switch (nft_pf(pkt)) {
91 	case NFPROTO_IPV4:
92 		sk = nf_sk_lookup_slow_v4(nft_net(pkt), skb, indev);
93 		break;
94 #if IS_ENABLED(CONFIG_NF_TABLES_IPV6)
95 	case NFPROTO_IPV6:
96 		sk = nf_sk_lookup_slow_v6(nft_net(pkt), skb, indev);
97 		break;
98 #endif
99 	default:
100 		WARN_ON_ONCE(1);
101 		break;
102 	}
103 
104 	return sk;
105 }
106 
nft_socket_eval(const struct nft_expr * expr,struct nft_regs * regs,const struct nft_pktinfo * pkt)107 static void nft_socket_eval(const struct nft_expr *expr,
108 			    struct nft_regs *regs,
109 			    const struct nft_pktinfo *pkt)
110 {
111 	const struct nft_socket *priv = nft_expr_priv(expr);
112 	struct sk_buff *skb = pkt->skb;
113 	struct sock *sk = skb->sk;
114 	u32 *dest = &regs->data[priv->dreg];
115 
116 	if (sk && !net_eq(nft_net(pkt), sock_net(sk)))
117 		sk = NULL;
118 
119 	if (!sk)
120 		sk = nft_socket_do_lookup(pkt);
121 
122 	if (!sk) {
123 		regs->verdict.code = NFT_BREAK;
124 		return;
125 	}
126 
127 	switch(priv->key) {
128 	case NFT_SOCKET_TRANSPARENT:
129 		nft_reg_store8(dest, inet_sk_transparent(sk));
130 		break;
131 	case NFT_SOCKET_MARK:
132 		if (sk_fullsock(sk)) {
133 			*dest = READ_ONCE(sk->sk_mark);
134 		} else {
135 			regs->verdict.code = NFT_BREAK;
136 			goto out_put_sk;
137 		}
138 		break;
139 	case NFT_SOCKET_WILDCARD:
140 		if (!sk_fullsock(sk)) {
141 			regs->verdict.code = NFT_BREAK;
142 			goto out_put_sk;
143 		}
144 		nft_socket_wildcard(pkt, regs, sk, dest);
145 		break;
146 #ifdef CONFIG_SOCK_CGROUP_DATA
147 	case NFT_SOCKET_CGROUPV2:
148 		if (!nft_sock_get_eval_cgroupv2(dest, sk, pkt, priv->level)) {
149 			regs->verdict.code = NFT_BREAK;
150 			goto out_put_sk;
151 		}
152 		break;
153 #endif
154 	default:
155 		WARN_ON(1);
156 		regs->verdict.code = NFT_BREAK;
157 	}
158 
159 out_put_sk:
160 	if (sk != skb->sk)
161 		sock_gen_put(sk);
162 }
163 
164 static const struct nla_policy nft_socket_policy[NFTA_SOCKET_MAX + 1] = {
165 	[NFTA_SOCKET_KEY]		= NLA_POLICY_MAX(NLA_BE32, 255),
166 	[NFTA_SOCKET_DREG]		= { .type = NLA_U32 },
167 	[NFTA_SOCKET_LEVEL]		= NLA_POLICY_MAX(NLA_BE32, 255),
168 };
169 
nft_socket_init(const struct nft_ctx * ctx,const struct nft_expr * expr,const struct nlattr * const tb[])170 static int nft_socket_init(const struct nft_ctx *ctx,
171 			   const struct nft_expr *expr,
172 			   const struct nlattr * const tb[])
173 {
174 	struct nft_socket *priv = nft_expr_priv(expr);
175 	unsigned int len;
176 
177 	if (!tb[NFTA_SOCKET_DREG] || !tb[NFTA_SOCKET_KEY])
178 		return -EINVAL;
179 
180 	switch(ctx->family) {
181 	case NFPROTO_IPV4:
182 #if IS_ENABLED(CONFIG_NF_TABLES_IPV6)
183 	case NFPROTO_IPV6:
184 #endif
185 	case NFPROTO_INET:
186 		break;
187 	default:
188 		return -EOPNOTSUPP;
189 	}
190 
191 	priv->key = ntohl(nla_get_be32(tb[NFTA_SOCKET_KEY]));
192 	switch(priv->key) {
193 	case NFT_SOCKET_TRANSPARENT:
194 	case NFT_SOCKET_WILDCARD:
195 		len = sizeof(u8);
196 		break;
197 	case NFT_SOCKET_MARK:
198 		len = sizeof(u32);
199 		break;
200 #ifdef CONFIG_SOCK_CGROUP_DATA
201 	case NFT_SOCKET_CGROUPV2: {
202 		unsigned int level;
203 		int err;
204 
205 		if (!tb[NFTA_SOCKET_LEVEL])
206 			return -EINVAL;
207 
208 		level = ntohl(nla_get_be32(tb[NFTA_SOCKET_LEVEL]));
209 		if (level > 255)
210 			return -EOPNOTSUPP;
211 
212 		err = nft_socket_cgroup_subtree_level();
213 		if (err < 0)
214 			return err;
215 
216 		priv->level_user = level;
217 
218 		level += err;
219 		/* Implies a giant cgroup tree */
220 		if (WARN_ON_ONCE(level > 255))
221 			return -EOPNOTSUPP;
222 
223 		priv->level = level;
224 		len = sizeof(u64);
225 		break;
226 	}
227 #endif
228 	default:
229 		return -EOPNOTSUPP;
230 	}
231 
232 	priv->len = len;
233 	return nft_parse_register_store(ctx, tb[NFTA_SOCKET_DREG], &priv->dreg,
234 					NULL, NFT_DATA_VALUE, len);
235 }
236 
nft_socket_dump(struct sk_buff * skb,const struct nft_expr * expr,bool reset)237 static int nft_socket_dump(struct sk_buff *skb,
238 			   const struct nft_expr *expr, bool reset)
239 {
240 	const struct nft_socket *priv = nft_expr_priv(expr);
241 
242 	if (nla_put_be32(skb, NFTA_SOCKET_KEY, htonl(priv->key)))
243 		return -1;
244 	if (nft_dump_register(skb, NFTA_SOCKET_DREG, priv->dreg))
245 		return -1;
246 	if (priv->key == NFT_SOCKET_CGROUPV2 &&
247 	    nla_put_be32(skb, NFTA_SOCKET_LEVEL, htonl(priv->level_user)))
248 		return -1;
249 	return 0;
250 }
251 
nft_socket_reduce(struct nft_regs_track * track,const struct nft_expr * expr)252 static bool nft_socket_reduce(struct nft_regs_track *track,
253 			      const struct nft_expr *expr)
254 {
255 	const struct nft_socket *priv = nft_expr_priv(expr);
256 	const struct nft_socket *socket;
257 
258 	if (!nft_reg_track_cmp(track, expr, priv->dreg)) {
259 		nft_reg_track_update(track, expr, priv->dreg, priv->len);
260 		return false;
261 	}
262 
263 	socket = nft_expr_priv(track->regs[priv->dreg].selector);
264 	if (priv->key != socket->key ||
265 	    priv->dreg != socket->dreg ||
266 	    priv->level != socket->level) {
267 		nft_reg_track_update(track, expr, priv->dreg, priv->len);
268 		return false;
269 	}
270 
271 	if (!track->regs[priv->dreg].bitwise)
272 		return true;
273 
274 	return nft_expr_reduce_bitwise(track, expr);
275 }
276 
nft_socket_validate(const struct nft_ctx * ctx,const struct nft_expr * expr)277 static int nft_socket_validate(const struct nft_ctx *ctx,
278 			       const struct nft_expr *expr)
279 {
280 	if (ctx->family != NFPROTO_IPV4 &&
281 	    ctx->family != NFPROTO_IPV6 &&
282 	    ctx->family != NFPROTO_INET)
283 		return -EOPNOTSUPP;
284 
285 	return nft_chain_validate_hooks(ctx->chain,
286 					(1 << NF_INET_PRE_ROUTING) |
287 					(1 << NF_INET_LOCAL_IN) |
288 					(1 << NF_INET_LOCAL_OUT));
289 }
290 
291 static struct nft_expr_type nft_socket_type;
292 static const struct nft_expr_ops nft_socket_ops = {
293 	.type		= &nft_socket_type,
294 	.size		= NFT_EXPR_SIZE(sizeof(struct nft_socket)),
295 	.eval		= nft_socket_eval,
296 	.init		= nft_socket_init,
297 	.dump		= nft_socket_dump,
298 	.validate	= nft_socket_validate,
299 	.reduce		= nft_socket_reduce,
300 };
301 
302 static struct nft_expr_type nft_socket_type __read_mostly = {
303 	.name		= "socket",
304 	.ops		= &nft_socket_ops,
305 	.policy		= nft_socket_policy,
306 	.maxattr	= NFTA_SOCKET_MAX,
307 	.owner		= THIS_MODULE,
308 };
309 
nft_socket_module_init(void)310 static int __init nft_socket_module_init(void)
311 {
312 	return nft_register_expr(&nft_socket_type);
313 }
314 
nft_socket_module_exit(void)315 static void __exit nft_socket_module_exit(void)
316 {
317 	nft_unregister_expr(&nft_socket_type);
318 }
319 
320 module_init(nft_socket_module_init);
321 module_exit(nft_socket_module_exit);
322 
323 MODULE_LICENSE("GPL");
324 MODULE_AUTHOR("Máté Eckl");
325 MODULE_DESCRIPTION("nf_tables socket match module");
326 MODULE_ALIAS_NFT_EXPR("socket");
327