1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * xfrm_nat_keepalive.c
4  *
5  * (c) 2024 Eyal Birger <eyal.birger@gmail.com>
6  */
7 
8 #include <net/inet_common.h>
9 #include <net/ip6_checksum.h>
10 #include <net/xfrm.h>
11 
/*
 * Per-CPU kernel control sockets used as the sending socket for NAT
 * keepalives, one set per address family. They are created in init_net by
 * nat_keepalive_sk_init() and are switched to the SA's netns with
 * sock_net_set() only for the duration of a single transmit.
 */
static DEFINE_PER_CPU(struct sock *, nat_keepalive_sk_ipv4);
#if IS_ENABLED(CONFIG_IPV6)
static DEFINE_PER_CPU(struct sock *, nat_keepalive_sk_ipv6);
#endif
16 
/*
 * Copy of the xfrm_state fields needed to build and route one keepalive.
 * It is filled while x->lock is held so that the packet itself can be sent
 * after the lock has been dropped.
 */
struct nat_keepalive {
	struct net *net;		/* netns of the SA (xs_net(x)) */
	u16 family;			/* address family of the SA (x->props.family) */
	xfrm_address_t saddr;		/* source address (x->props.saddr) */
	xfrm_address_t daddr;		/* destination address (x->id.daddr) */
	__be16 encap_sport;		/* UDP encapsulation source port */
	__be16 encap_dport;		/* UDP encapsulation destination port */
	__u32 smark;			/* mark copied into skb->mark before sending */
};
26 
nat_keepalive_init(struct nat_keepalive * ka,struct xfrm_state * x)27 static void nat_keepalive_init(struct nat_keepalive *ka, struct xfrm_state *x)
28 {
29 	ka->net = xs_net(x);
30 	ka->family = x->props.family;
31 	ka->saddr = x->props.saddr;
32 	ka->daddr = x->id.daddr;
33 	ka->encap_sport = x->encap->encap_sport;
34 	ka->encap_dport = x->encap->encap_dport;
35 	ka->smark = xfrm_smark_get(0, x);
36 }
37 
nat_keepalive_send_ipv4(struct sk_buff * skb,struct nat_keepalive * ka)38 static int nat_keepalive_send_ipv4(struct sk_buff *skb,
39 				   struct nat_keepalive *ka)
40 {
41 	struct net *net = ka->net;
42 	struct flowi4 fl4;
43 	struct rtable *rt;
44 	struct sock *sk;
45 	__u8 tos = 0;
46 	int err;
47 
48 	flowi4_init_output(&fl4, 0 /* oif */, skb->mark, tos,
49 			   RT_SCOPE_UNIVERSE, IPPROTO_UDP, 0,
50 			   ka->daddr.a4, ka->saddr.a4, ka->encap_dport,
51 			   ka->encap_sport, sock_net_uid(net, NULL));
52 
53 	rt = ip_route_output_key(net, &fl4);
54 	if (IS_ERR(rt))
55 		return PTR_ERR(rt);
56 
57 	skb_dst_set(skb, &rt->dst);
58 
59 	sk = *this_cpu_ptr(&nat_keepalive_sk_ipv4);
60 	sock_net_set(sk, net);
61 	err = ip_build_and_send_pkt(skb, sk, fl4.saddr, fl4.daddr, NULL, tos);
62 	sock_net_set(sk, &init_net);
63 	return err;
64 }
65 
66 #if IS_ENABLED(CONFIG_IPV6)
nat_keepalive_send_ipv6(struct sk_buff * skb,struct nat_keepalive * ka,struct udphdr * uh)67 static int nat_keepalive_send_ipv6(struct sk_buff *skb,
68 				   struct nat_keepalive *ka,
69 				   struct udphdr *uh)
70 {
71 	struct net *net = ka->net;
72 	struct dst_entry *dst;
73 	struct flowi6 fl6;
74 	struct sock *sk;
75 	__wsum csum;
76 	int err;
77 
78 	csum = skb_checksum(skb, 0, skb->len, 0);
79 	uh->check = csum_ipv6_magic(&ka->saddr.in6, &ka->daddr.in6,
80 				    skb->len, IPPROTO_UDP, csum);
81 	if (uh->check == 0)
82 		uh->check = CSUM_MANGLED_0;
83 
84 	memset(&fl6, 0, sizeof(fl6));
85 	fl6.flowi6_mark = skb->mark;
86 	fl6.saddr = ka->saddr.in6;
87 	fl6.daddr = ka->daddr.in6;
88 	fl6.flowi6_proto = IPPROTO_UDP;
89 	fl6.fl6_sport = ka->encap_sport;
90 	fl6.fl6_dport = ka->encap_dport;
91 
92 	sk = *this_cpu_ptr(&nat_keepalive_sk_ipv6);
93 	sock_net_set(sk, net);
94 	dst = ipv6_stub->ipv6_dst_lookup_flow(net, sk, &fl6, NULL);
95 	if (IS_ERR(dst))
96 		return PTR_ERR(dst);
97 
98 	skb_dst_set(skb, dst);
99 	err = ipv6_stub->ip6_xmit(sk, skb, &fl6, skb->mark, NULL, 0, 0);
100 	sock_net_set(sk, &init_net);
101 	return err;
102 }
103 #endif
104 
nat_keepalive_send(struct nat_keepalive * ka)105 static void nat_keepalive_send(struct nat_keepalive *ka)
106 {
107 	const int nat_ka_hdrs_len = max(sizeof(struct iphdr),
108 					sizeof(struct ipv6hdr)) +
109 				    sizeof(struct udphdr);
110 	const u8 nat_ka_payload = 0xFF;
111 	int err = -EAFNOSUPPORT;
112 	struct sk_buff *skb;
113 	struct udphdr *uh;
114 
115 	skb = alloc_skb(nat_ka_hdrs_len + sizeof(nat_ka_payload), GFP_ATOMIC);
116 	if (unlikely(!skb))
117 		return;
118 
119 	skb_reserve(skb, nat_ka_hdrs_len);
120 
121 	skb_put_u8(skb, nat_ka_payload);
122 
123 	uh = skb_push(skb, sizeof(*uh));
124 	uh->source = ka->encap_sport;
125 	uh->dest = ka->encap_dport;
126 	uh->len = htons(skb->len);
127 	uh->check = 0;
128 
129 	skb->mark = ka->smark;
130 
131 	switch (ka->family) {
132 	case AF_INET:
133 		err = nat_keepalive_send_ipv4(skb, ka);
134 		break;
135 #if IS_ENABLED(CONFIG_IPV6)
136 	case AF_INET6:
137 		err = nat_keepalive_send_ipv6(skb, ka, uh);
138 		break;
139 #endif
140 	}
141 	if (err)
142 		kfree_skb(skb);
143 }
144 
/* State shared by nat_keepalive_work() across one xfrm_state_walk(). */
struct nat_keepalive_work_ctx {
	time64_t next_run;	/* earliest deadline collected so far, in seconds; 0 = none */
	time64_t now;		/* ktime_get_real_seconds() sampled before the walk */
};
149 
nat_keepalive_work_single(struct xfrm_state * x,int count,void * ptr)150 static int nat_keepalive_work_single(struct xfrm_state *x, int count, void *ptr)
151 {
152 	struct nat_keepalive_work_ctx *ctx = ptr;
153 	bool send_keepalive = false;
154 	struct nat_keepalive ka;
155 	time64_t next_run;
156 	u32 interval;
157 	int delta;
158 
159 	interval = x->nat_keepalive_interval;
160 	if (!interval)
161 		return 0;
162 
163 	spin_lock(&x->lock);
164 
165 	delta = (int)(ctx->now - x->lastused);
166 	if (delta < interval) {
167 		x->nat_keepalive_expiration = ctx->now + interval - delta;
168 		next_run = x->nat_keepalive_expiration;
169 	} else if (x->nat_keepalive_expiration > ctx->now) {
170 		next_run = x->nat_keepalive_expiration;
171 	} else {
172 		next_run = ctx->now + interval;
173 		nat_keepalive_init(&ka, x);
174 		send_keepalive = true;
175 	}
176 
177 	spin_unlock(&x->lock);
178 
179 	if (send_keepalive)
180 		nat_keepalive_send(&ka);
181 
182 	if (!ctx->next_run || next_run < ctx->next_run)
183 		ctx->next_run = next_run;
184 	return 0;
185 }
186 
nat_keepalive_work(struct work_struct * work)187 static void nat_keepalive_work(struct work_struct *work)
188 {
189 	struct nat_keepalive_work_ctx ctx;
190 	struct xfrm_state_walk walk;
191 	struct net *net;
192 
193 	ctx.next_run = 0;
194 	ctx.now = ktime_get_real_seconds();
195 
196 	net = container_of(work, struct net, xfrm.nat_keepalive_work.work);
197 	xfrm_state_walk_init(&walk, IPPROTO_ESP, NULL);
198 	xfrm_state_walk(net, &walk, nat_keepalive_work_single, &ctx);
199 	xfrm_state_walk_done(&walk, net);
200 	if (ctx.next_run)
201 		schedule_delayed_work(&net->xfrm.nat_keepalive_work,
202 				      (ctx.next_run - ctx.now) * HZ);
203 }
204 
/*
 * nat_keepalive_sk_init - create one control socket per possible CPU
 * @socks:  per-CPU slots to fill
 * @family: protocol family passed to inet_ctl_sock_create()
 *
 * Each socket is a SOCK_RAW/IPPROTO_UDP control socket created in init_net.
 * On failure, only the sockets created by this call are destroyed, and
 * every slot is reset to NULL so no dangling pointer is left behind.
 * Returns 0, or the negative errno from inet_ctl_sock_create().
 */
static int nat_keepalive_sk_init(struct sock * __percpu *socks,
				 unsigned short family)
{
	struct sock *sk;
	int err, i, j;

	for_each_possible_cpu(i) {
		err = inet_ctl_sock_create(&sk, family, SOCK_RAW, IPPROTO_UDP,
					   &init_net);
		if (err < 0)
			goto err;

		*per_cpu_ptr(socks, i) = sk;
	}

	return 0;
err:
	/*
	 * for_each_possible_cpu() visits CPUs in ascending order, so exactly the
	 * CPUs below @i received a socket in this call. Slots at or above @i
	 * may still hold stale pointers (e.g. from an earlier init/fini cycle)
	 * that must not be released a second time.
	 */
	for_each_possible_cpu(j) {
		if (j < i)
			inet_ctl_sock_destroy(*per_cpu_ptr(socks, j));
		*per_cpu_ptr(socks, j) = NULL;
	}
	return err;
}
226 
/*
 * nat_keepalive_sk_fini - release the per-CPU control sockets in @socks
 * @socks: per-CPU slots previously filled by nat_keepalive_sk_init()
 *
 * Each slot is reset to NULL after release. A later init, including a failed
 * one, therefore never sees a pointer to an already-released socket.
 * inet_ctl_sock_destroy() is passed whatever the slot holds, including NULL.
 */
static void nat_keepalive_sk_fini(struct sock * __percpu *socks)
{
	int i;

	for_each_possible_cpu(i) {
		struct sock **slot = per_cpu_ptr(socks, i);

		inet_ctl_sock_destroy(*slot);
		*slot = NULL;
	}
}
234 
xfrm_nat_keepalive_state_updated(struct xfrm_state * x)235 void xfrm_nat_keepalive_state_updated(struct xfrm_state *x)
236 {
237 	struct net *net;
238 
239 	if (!x->nat_keepalive_interval)
240 		return;
241 
242 	net = xs_net(x);
243 	schedule_delayed_work(&net->xfrm.nat_keepalive_work, 0);
244 }
245 
/*
 * xfrm_nat_keepalive_net_init - per-netns setup of the keepalive work
 * @net: network namespace being initialised
 *
 * Only binds net->xfrm.nat_keepalive_work to nat_keepalive_work(). Nothing
 * is armed here: the work is first scheduled by
 * xfrm_nat_keepalive_state_updated() and then re-arms itself. Always
 * returns 0.
 */
int __net_init xfrm_nat_keepalive_net_init(struct net *net)
{
	INIT_DELAYED_WORK(&net->xfrm.nat_keepalive_work, nat_keepalive_work);
	return 0;
}
251 
xfrm_nat_keepalive_net_fini(struct net * net)252 int xfrm_nat_keepalive_net_fini(struct net *net)
253 {
254 	cancel_delayed_work_sync(&net->xfrm.nat_keepalive_work);
255 	return 0;
256 }
257 
xfrm_nat_keepalive_init(unsigned short family)258 int xfrm_nat_keepalive_init(unsigned short family)
259 {
260 	int err = -EAFNOSUPPORT;
261 
262 	switch (family) {
263 	case AF_INET:
264 		err = nat_keepalive_sk_init(&nat_keepalive_sk_ipv4, PF_INET);
265 		break;
266 #if IS_ENABLED(CONFIG_IPV6)
267 	case AF_INET6:
268 		err = nat_keepalive_sk_init(&nat_keepalive_sk_ipv6, PF_INET6);
269 		break;
270 #endif
271 	}
272 
273 	if (err)
274 		pr_err("xfrm nat keepalive init: failed to init err:%d\n", err);
275 	return err;
276 }
277 EXPORT_SYMBOL_GPL(xfrm_nat_keepalive_init);
278 
xfrm_nat_keepalive_fini(unsigned short family)279 void xfrm_nat_keepalive_fini(unsigned short family)
280 {
281 	switch (family) {
282 	case AF_INET:
283 		nat_keepalive_sk_fini(&nat_keepalive_sk_ipv4);
284 		break;
285 #if IS_ENABLED(CONFIG_IPV6)
286 	case AF_INET6:
287 		nat_keepalive_sk_fini(&nat_keepalive_sk_ipv6);
288 		break;
289 #endif
290 	}
291 }
292 EXPORT_SYMBOL_GPL(xfrm_nat_keepalive_fini);
293