// SPDX-License-Identifier: GPL-2.0-only
/* Copyright (c) 2022 Meta Platforms, Inc. and affiliates. */
#include <linux/mm.h>
#include <linux/llist.h>
#include <linux/bpf.h>
#include <linux/irq_work.h>
#include <linux/bpf_mem_alloc.h>
#include <linux/memcontrol.h>
#include <asm/local.h>

/* Any context (including NMI) BPF specific memory allocator.
 *
 * Tracing BPF programs can attach to kprobe and fentry. Hence they
 * run in an unknown context where calling plain kmalloc() might not be safe.
 *
 * Front-end kmalloc() with a per-cpu per-bucket cache of free elements.
 * Refill this cache asynchronously from irq_work.
 *
 * CPU_0 buckets
 * 16 32 64 96 128 192 256 512 1024 2048 4096
 * ...
 * CPU_N buckets
 * 16 32 64 96 128 192 256 512 1024 2048 4096
 *
 * The buckets are prefilled at the start.
 * BPF programs always run with migration disabled.
 * It's safe to allocate from the cache of the current cpu with irqs disabled.
 * Freeing is always done into the bucket of the current cpu as well.
 * irq_work trims extra free elements from the buckets with kfree
 * and refills them with kmalloc, so the global kmalloc logic takes care
 * of freeing objects allocated by one cpu and freed on another.
 *
 * Every allocated object is padded with an extra 8 bytes that contain
 * a struct llist_node.
 */
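
/* A minimal usage sketch of the API defined below (illustrative only,
 * error handling elided). Passing size == 0 sets up the 11 per-cpu
 * buckets and sizes up to BPF_MEM_ALLOC_SIZE_MAX (less LLIST_NODE_SZ
 * for the hidden header) can then be requested:
 *
 *	struct bpf_mem_alloc ma;
 *	void *obj;
 *
 *	if (bpf_mem_alloc_init(&ma, 0, false))
 *		return -ENOMEM;
 *	obj = bpf_mem_alloc(&ma, 64);
 *	if (obj)
 *		bpf_mem_free(&ma, obj);
 *	bpf_mem_alloc_destroy(&ma);
 */
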
#define LLIST_NODE_SZ sizeof(struct llist_node)

#define BPF_MEM_ALLOC_SIZE_MAX 4096

/* similar to kmalloc, but sizeof == 8 bucket is gone */
static u8 size_index[24] __ro_after_init = {
	3,	/* 8 */
	3,	/* 16 */
	4,	/* 24 */
	4,	/* 32 */
	5,	/* 40 */
	5,	/* 48 */
	5,	/* 56 */
	5,	/* 64 */
	1,	/* 72 */
	1,	/* 80 */
	1,	/* 88 */
	1,	/* 96 */
	6,	/* 104 */
	6,	/* 112 */
	6,	/* 120 */
	6,	/* 128 */
	2,	/* 136 */
	2,	/* 144 */
	2,	/* 152 */
	2,	/* 160 */
	2,	/* 168 */
	2,	/* 176 */
	2,	/* 184 */
	2	/* 192 */
};

static int bpf_mem_cache_idx(size_t size)
{
	if (!size || size > BPF_MEM_ALLOC_SIZE_MAX)
		return -1;

	if (size <= 192)
		return size_index[(size - 1) / 8] - 1;

	return fls(size - 1) - 2;
}
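
/* Worked examples (assuming the sizes[] table below): a size of 72 maps via
 * size_index[(72 - 1) / 8] == size_index[8] == 1 to index 0, i.e. the
 * 96 byte bucket; a size of 300 maps via fls(299) - 2 == 9 - 2 to index 7,
 * i.e. the 512 byte bucket.
 */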

#define NUM_CACHES 11

struct bpf_mem_cache {
	/* per-cpu list of free objects of size 'unit_size'.
	 * All accesses are done with interrupts disabled and 'active' counter
	 * protection with __llist_add() and __llist_del_first().
	 */
	struct llist_head free_llist;
	local_t active;

	/* Operations on the free_llist from unit_alloc/unit_free/bpf_mem_refill
	 * are sequenced by per-cpu 'active' counter. But unit_free() cannot
	 * fail. When 'active' is busy the unit_free() will add an object to
	 * free_llist_extra.
	 */
	struct llist_head free_llist_extra;

	struct irq_work refill_work;
	struct obj_cgroup *objcg;
	int unit_size;
	/* count of objects in free_llist */
	int free_cnt;
	int low_watermark, high_watermark, batch;
	int percpu_size;
	bool draining;
	struct bpf_mem_cache *tgt;

	/* list of objects to be freed after RCU GP */
	struct llist_head free_by_rcu;
	struct llist_node *free_by_rcu_tail;
	struct llist_head waiting_for_gp;
	struct llist_node *waiting_for_gp_tail;
	struct rcu_head rcu;
	atomic_t call_rcu_in_progress;
	struct llist_head free_llist_extra_rcu;

	/* list of objects to be freed after RCU tasks trace GP */
	struct llist_head free_by_rcu_ttrace;
	struct llist_head waiting_for_gp_ttrace;
	struct rcu_head rcu_ttrace;
	atomic_t call_rcu_ttrace_in_progress;
};

struct bpf_mem_caches {
	struct bpf_mem_cache cache[NUM_CACHES];
};

static const u16 sizes[NUM_CACHES] = {96, 192, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096};

static struct llist_node notrace *__llist_del_first(struct llist_head *head)
{
	struct llist_node *entry, *next;

	entry = head->first;
	if (!entry)
		return NULL;
	next = entry->next;
	head->first = next;
	return entry;
}

static void *__alloc(struct bpf_mem_cache *c, int node, gfp_t flags)
{
	if (c->percpu_size) {
		void __percpu **obj = kmalloc_node(c->percpu_size, flags, node);
		void __percpu *pptr = __alloc_percpu_gfp(c->unit_size, 8, flags);

		if (!obj || !pptr) {
			free_percpu(pptr);
			kfree(obj);
			return NULL;
		}
		obj[1] = pptr;
		return obj;
	}

	return kmalloc_node(c->unit_size, flags | __GFP_ZERO, node);
}
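
/* A summary of the percpu case above, not an additional contract: __alloc()
 * returns a small kmalloc-ed array of two pointer slots. Slot [0] doubles as
 * the llist_node linkage while the object sits on the free lists, and slot [1]
 * holds the __percpu pointer that free_one() later hands back to
 * free_percpu(). The caller-visible pointer ends up pointing at slot [1],
 * consistent with the LLIST_NODE_SZ offset applied by bpf_mem_alloc().
 */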

static struct mem_cgroup *get_memcg(const struct bpf_mem_cache *c)
{
#ifdef CONFIG_MEMCG
	if (c->objcg)
		return get_mem_cgroup_from_objcg(c->objcg);
	return root_mem_cgroup;
#else
	return NULL;
#endif
}

static void inc_active(struct bpf_mem_cache *c, unsigned long *flags)
{
	if (IS_ENABLED(CONFIG_PREEMPT_RT))
		/* In RT irq_work runs in a per-cpu kthread, so disable
		 * interrupts to avoid being preempted or interrupted and
		 * reduce the chance of a bpf prog executing on this cpu
		 * when the active counter is busy.
		 */
		local_irq_save(*flags);
	/* alloc_bulk runs from irq_work which will not preempt a bpf
	 * program that does unit_alloc/unit_free since IRQs are
	 * disabled there. There is no race to increment 'active'
	 * counter. It protects free_llist from corruption in case NMI
	 * bpf prog preempted this loop.
	 */
	WARN_ON_ONCE(local_inc_return(&c->active) != 1);
}

static void dec_active(struct bpf_mem_cache *c, unsigned long *flags)
{
	local_dec(&c->active);
	if (IS_ENABLED(CONFIG_PREEMPT_RT))
		local_irq_restore(*flags);
}
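
/* Typical critical-section pattern built on the pair above (see
 * add_obj_to_free_list(), free_bulk() and check_free_by_rcu() below):
 *
 *	unsigned long flags;
 *
 *	inc_active(c, &flags);
 *	... __llist_add() / __llist_del_first() on the per-cpu lists ...
 *	dec_active(c, &flags);
 */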

static void add_obj_to_free_list(struct bpf_mem_cache *c, void *obj)
{
	unsigned long flags;

	inc_active(c, &flags);
	__llist_add(obj, &c->free_llist);
	c->free_cnt++;
	dec_active(c, &flags);
}

/* Mostly runs from irq_work except __init phase. */
static void alloc_bulk(struct bpf_mem_cache *c, int cnt, int node, bool atomic)
{
	struct mem_cgroup *memcg = NULL, *old_memcg;
	gfp_t gfp;
	void *obj;
	int i;

	gfp = __GFP_NOWARN | __GFP_ACCOUNT;
	gfp |= atomic ? GFP_NOWAIT : GFP_KERNEL;

	for (i = 0; i < cnt; i++) {
		/*
		 * For every 'c' llist_del_first(&c->free_by_rcu_ttrace); is
		 * done only by one CPU == current CPU. Other CPUs might
		 * llist_add() and llist_del_all() in parallel.
		 */
		obj = llist_del_first(&c->free_by_rcu_ttrace);
		if (!obj)
			break;
		add_obj_to_free_list(c, obj);
	}
	if (i >= cnt)
		return;

	for (; i < cnt; i++) {
		obj = llist_del_first(&c->waiting_for_gp_ttrace);
		if (!obj)
			break;
		add_obj_to_free_list(c, obj);
	}
	if (i >= cnt)
		return;

	memcg = get_memcg(c);
	old_memcg = set_active_memcg(memcg);
	for (; i < cnt; i++) {
		/* Allocate, but don't deplete atomic reserves that typical
		 * GFP_ATOMIC would do. irq_work runs on this cpu and kmalloc
		 * will allocate from the current numa node which is what we
		 * want here.
		 */
		obj = __alloc(c, node, gfp);
		if (!obj)
			break;
		add_obj_to_free_list(c, obj);
	}
	set_active_memcg(old_memcg);
	mem_cgroup_put(memcg);
}

static void free_one(void *obj, bool percpu)
{
	if (percpu) {
		free_percpu(((void __percpu **)obj)[1]);
		kfree(obj);
		return;
	}

	kfree(obj);
}

static int free_all(struct llist_node *llnode, bool percpu)
{
	struct llist_node *pos, *t;
	int cnt = 0;

	llist_for_each_safe(pos, t, llnode) {
		free_one(pos, percpu);
		cnt++;
	}
	return cnt;
}

static void __free_rcu(struct rcu_head *head)
{
	struct bpf_mem_cache *c = container_of(head, struct bpf_mem_cache, rcu_ttrace);

	free_all(llist_del_all(&c->waiting_for_gp_ttrace), !!c->percpu_size);
	atomic_set(&c->call_rcu_ttrace_in_progress, 0);
}

static void __free_rcu_tasks_trace(struct rcu_head *head)
{
	/* If RCU Tasks Trace grace period implies RCU grace period,
	 * there is no need to invoke call_rcu().
	 */
	if (rcu_trace_implies_rcu_gp())
		__free_rcu(head);
	else
		call_rcu(head, __free_rcu);
}

static void enque_to_free(struct bpf_mem_cache *c, void *obj)
{
	struct llist_node *llnode = obj;

	/* bpf_mem_cache is a per-cpu object. Freeing happens in irq_work.
	 * Nothing races to add to free_by_rcu_ttrace list.
	 */
	llist_add(llnode, &c->free_by_rcu_ttrace);
}

static void do_call_rcu_ttrace(struct bpf_mem_cache *c)
{
	struct llist_node *llnode, *t;

	if (atomic_xchg(&c->call_rcu_ttrace_in_progress, 1)) {
		if (unlikely(READ_ONCE(c->draining))) {
			llnode = llist_del_all(&c->free_by_rcu_ttrace);
			free_all(llnode, !!c->percpu_size);
		}
		return;
	}

	WARN_ON_ONCE(!llist_empty(&c->waiting_for_gp_ttrace));
	llist_for_each_safe(llnode, t, llist_del_all(&c->free_by_rcu_ttrace))
		llist_add(llnode, &c->waiting_for_gp_ttrace);

	if (unlikely(READ_ONCE(c->draining))) {
		__free_rcu(&c->rcu_ttrace);
		return;
	}

	/* Use call_rcu_tasks_trace() to wait for sleepable progs to finish.
	 * If RCU Tasks Trace grace period implies RCU grace period, free
	 * these elements directly, else use call_rcu() to wait for normal
	 * progs to finish and finally do free_one() on each element.
	 */
	call_rcu_tasks_trace(&c->rcu_ttrace, __free_rcu_tasks_trace);
}

static void free_bulk(struct bpf_mem_cache *c)
{
	struct bpf_mem_cache *tgt = c->tgt;
	struct llist_node *llnode, *t;
	unsigned long flags;
	int cnt;

	WARN_ON_ONCE(tgt->unit_size != c->unit_size);
	WARN_ON_ONCE(tgt->percpu_size != c->percpu_size);

	do {
		inc_active(c, &flags);
		llnode = __llist_del_first(&c->free_llist);
		if (llnode)
			cnt = --c->free_cnt;
		else
			cnt = 0;
		dec_active(c, &flags);
		if (llnode)
			enque_to_free(tgt, llnode);
	} while (cnt > (c->high_watermark + c->low_watermark) / 2);

	/* and drain free_llist_extra */
	llist_for_each_safe(llnode, t, llist_del_all(&c->free_llist_extra))
		enque_to_free(tgt, llnode);
	do_call_rcu_ttrace(tgt);
}
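
/* Summary of the non-RCU free path (as implemented above): unit_free()
 * pushes the object onto the per-cpu free_llist (or free_llist_extra when
 * 'active' is busy); once free_cnt exceeds the high watermark, free_bulk()
 * running from irq_work moves surplus objects to the target cache's
 * free_by_rcu_ttrace list, and do_call_rcu_ttrace() frees them with kfree()
 * after an RCU tasks trace (and, if needed, a regular RCU) grace period.
 */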

static void __free_by_rcu(struct rcu_head *head)
{
	struct bpf_mem_cache *c = container_of(head, struct bpf_mem_cache, rcu);
	struct bpf_mem_cache *tgt = c->tgt;
	struct llist_node *llnode;

	WARN_ON_ONCE(tgt->unit_size != c->unit_size);
	WARN_ON_ONCE(tgt->percpu_size != c->percpu_size);

	llnode = llist_del_all(&c->waiting_for_gp);
	if (!llnode)
		goto out;

	llist_add_batch(llnode, c->waiting_for_gp_tail, &tgt->free_by_rcu_ttrace);

	/* Objects went through regular RCU GP. Send them to RCU tasks trace */
	do_call_rcu_ttrace(tgt);
out:
	atomic_set(&c->call_rcu_in_progress, 0);
}

static void check_free_by_rcu(struct bpf_mem_cache *c)
{
	struct llist_node *llnode, *t;
	unsigned long flags;

	/* drain free_llist_extra_rcu */
	if (unlikely(!llist_empty(&c->free_llist_extra_rcu))) {
		inc_active(c, &flags);
		llist_for_each_safe(llnode, t, llist_del_all(&c->free_llist_extra_rcu))
			if (__llist_add(llnode, &c->free_by_rcu))
				c->free_by_rcu_tail = llnode;
		dec_active(c, &flags);
	}

	if (llist_empty(&c->free_by_rcu))
		return;

	if (atomic_xchg(&c->call_rcu_in_progress, 1)) {
		/*
		 * Instead of kmalloc-ing a new rcu_head and triggering 10k
		 * call_rcu() to hit rcutree.qhimark and force RCU to notice
		 * the overload, just ask RCU to hurry up. There could be many
		 * objects in the free_by_rcu list.
		 * This hint reduces memory consumption for an artificial
		 * benchmark from 2 Gbyte to 150 Mbyte.
		 */
		rcu_request_urgent_qs_task(current);
		return;
	}

	WARN_ON_ONCE(!llist_empty(&c->waiting_for_gp));

	inc_active(c, &flags);
	WRITE_ONCE(c->waiting_for_gp.first, __llist_del_all(&c->free_by_rcu));
	c->waiting_for_gp_tail = c->free_by_rcu_tail;
	dec_active(c, &flags);

	if (unlikely(READ_ONCE(c->draining))) {
		free_all(llist_del_all(&c->waiting_for_gp), !!c->percpu_size);
		atomic_set(&c->call_rcu_in_progress, 0);
	} else {
		call_rcu_hurry(&c->rcu, __free_by_rcu);
	}
}
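
/* Summary of the bpf_mem_free_rcu() path (as implemented above):
 * unit_free_rcu() pushes the object onto free_by_rcu (or
 * free_llist_extra_rcu when 'active' is busy); check_free_by_rcu() moves
 * the batch to waiting_for_gp and calls call_rcu_hurry(); after the RCU
 * grace period __free_by_rcu() hands the batch to the target cache's
 * free_by_rcu_ttrace list, so every object additionally waits for an RCU
 * tasks trace grace period before it is finally kfree-d.
 */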

static void bpf_mem_refill(struct irq_work *work)
{
	struct bpf_mem_cache *c = container_of(work, struct bpf_mem_cache, refill_work);
	int cnt;

	/* Racy access to free_cnt. It doesn't need to be 100% accurate */
	cnt = c->free_cnt;
	if (cnt < c->low_watermark)
		/* irq_work runs on this cpu and kmalloc will allocate
		 * from the current numa node which is what we want here.
		 */
		alloc_bulk(c, c->batch, NUMA_NO_NODE, true);
	else if (cnt > c->high_watermark)
		free_bulk(c);

	check_free_by_rcu(c);
}

static void notrace irq_work_raise(struct bpf_mem_cache *c)
{
	irq_work_queue(&c->refill_work);
}

/* For the typical bpf map case that uses bpf_mem_cache_alloc and a single
 * bucket, the freelist cache will be elem_size * 64 (or less) on each cpu.
 *
 * For bpf programs that don't have statically known allocation sizes,
 * assuming (low_mark + high_mark) / 2 as the average number of elements per
 * bucket and that all buckets are used, the total amount of memory in the
 * freelists on each cpu will be:
 * 64*16 + 64*32 + 64*64 + 64*96 + 64*128 + 64*192 + 64*256 + 32*512 + 16*1024 + 8*2048 + 4*4096
 * == ~116 Kbyte using the heuristic below.
 * An initialized but unused bpf allocator (not a bpf map specific one) will
 * consume ~11 Kbyte per cpu.
 * The typical case will be between 11K and 116K, closer to 11K.
 * bpf progs can and should share bpf_mem_cache when possible.
 *
 * Percpu allocation is typically rare. To avoid potentially unnecessary large
 * memory consumption, set low_mark = 1 and high_mark = 3, resulting in c->batch = 1.
 */
static void init_refill_work(struct bpf_mem_cache *c)
{
	init_irq_work(&c->refill_work, bpf_mem_refill);
	if (c->percpu_size) {
		c->low_watermark = 1;
		c->high_watermark = 3;
	} else if (c->unit_size <= 256) {
		c->low_watermark = 32;
		c->high_watermark = 96;
	} else {
		/* When page_size == 4k, order-0 cache will have low_mark == 2
		 * and high_mark == 6 with batch alloc of 3 individual pages at
		 * a time.
		 * 8k allocs and above low == 1, high == 3, batch == 1.
		 */
		c->low_watermark = max(32 * 256 / c->unit_size, 1);
		c->high_watermark = max(96 * 256 / c->unit_size, 3);
	}
	c->batch = max((c->high_watermark - c->low_watermark) / 4 * 3, 1);
}
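
/* Worked examples of the heuristic above: unit_size == 64 gives
 * low == 32, high == 96, batch == max((96 - 32) / 4 * 3, 1) == 48;
 * unit_size == 4096 gives low == max(32 * 256 / 4096, 1) == 2,
 * high == max(96 * 256 / 4096, 3) == 6, batch == max(4 / 4 * 3, 1) == 3;
 * any percpu cache gets low == 1, high == 3, batch == 1.
 */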

static void prefill_mem_cache(struct bpf_mem_cache *c, int cpu)
{
	int cnt = 1;

	/* To avoid consuming memory, for non-percpu allocation, assume that
	 * the 1st run of a bpf prog won't be doing more than 4 map_update_elem
	 * from an irq disabled region if the unit size is less than or equal
	 * to 256. For all other cases, let us just do one allocation.
	 */
	if (!c->percpu_size && c->unit_size <= 256)
		cnt = 4;
	alloc_bulk(c, cnt, cpu_to_node(cpu), false);
}

/* When size != 0 allocate a bpf_mem_cache for each cpu.
 * This is the typical bpf hash map use case when all elements have equal size.
 *
 * When size == 0 allocate 11 bpf_mem_cache-s for each cpu, then rely on
 * kmalloc/kfree. Max allocation size is 4096 in this case.
 * This is the bpf_dynptr and bpf_kptr use case.
 */
int bpf_mem_alloc_init(struct bpf_mem_alloc *ma, int size, bool percpu)
{
	struct bpf_mem_caches *cc; struct bpf_mem_caches __percpu *pcc;
	struct bpf_mem_cache *c; struct bpf_mem_cache __percpu *pc;
	struct obj_cgroup *objcg = NULL;
	int cpu, i, unit_size, percpu_size = 0;

	if (percpu && size == 0)
		return -EINVAL;

	/* room for llist_node and per-cpu pointer */
	if (percpu)
		percpu_size = LLIST_NODE_SZ + sizeof(void *);
	ma->percpu = percpu;

	if (size) {
		pc = __alloc_percpu_gfp(sizeof(*pc), 8, GFP_KERNEL);
		if (!pc)
			return -ENOMEM;

		if (!percpu)
			size += LLIST_NODE_SZ; /* room for llist_node */
		unit_size = size;

#ifdef CONFIG_MEMCG
		if (memcg_bpf_enabled())
			objcg = get_obj_cgroup_from_current();
#endif
		ma->objcg = objcg;

		for_each_possible_cpu(cpu) {
			c = per_cpu_ptr(pc, cpu);
			c->unit_size = unit_size;
			c->objcg = objcg;
			c->percpu_size = percpu_size;
			c->tgt = c;
			init_refill_work(c);
			prefill_mem_cache(c, cpu);
		}
		ma->cache = pc;
		return 0;
	}

	pcc = __alloc_percpu_gfp(sizeof(*cc), 8, GFP_KERNEL);
	if (!pcc)
		return -ENOMEM;
#ifdef CONFIG_MEMCG
	objcg = get_obj_cgroup_from_current();
#endif
	ma->objcg = objcg;
	for_each_possible_cpu(cpu) {
		cc = per_cpu_ptr(pcc, cpu);
		for (i = 0; i < NUM_CACHES; i++) {
			c = &cc->cache[i];
			c->unit_size = sizes[i];
			c->objcg = objcg;
			c->percpu_size = percpu_size;
			c->tgt = c;

			init_refill_work(c);
			prefill_mem_cache(c, cpu);
		}
	}

	ma->caches = pcc;
	return 0;
}

int bpf_mem_alloc_percpu_init(struct bpf_mem_alloc *ma, struct obj_cgroup *objcg)
{
	struct bpf_mem_caches __percpu *pcc;

	pcc = __alloc_percpu_gfp(sizeof(struct bpf_mem_caches), 8, GFP_KERNEL);
	if (!pcc)
		return -ENOMEM;

	ma->caches = pcc;
	ma->objcg = objcg;
	ma->percpu = true;
	return 0;
}

int bpf_mem_alloc_percpu_unit_init(struct bpf_mem_alloc *ma, int size)
{
	struct bpf_mem_caches *cc; struct bpf_mem_caches __percpu *pcc;
	int cpu, i, unit_size, percpu_size;
	struct obj_cgroup *objcg;
	struct bpf_mem_cache *c;

	i = bpf_mem_cache_idx(size);
	if (i < 0)
		return -EINVAL;

	/* room for llist_node and per-cpu pointer */
	percpu_size = LLIST_NODE_SZ + sizeof(void *);

	unit_size = sizes[i];
	objcg = ma->objcg;
	pcc = ma->caches;

	for_each_possible_cpu(cpu) {
		cc = per_cpu_ptr(pcc, cpu);
		c = &cc->cache[i];
		if (c->unit_size)
			break;

		c->unit_size = unit_size;
		c->objcg = objcg;
		c->percpu_size = percpu_size;
		c->tgt = c;

		init_refill_work(c);
		prefill_mem_cache(c, cpu);
	}

	return 0;
}

static void drain_mem_cache(struct bpf_mem_cache *c)
{
	bool percpu = !!c->percpu_size;

	/* No progs are using this bpf_mem_cache, but htab_map_free() called
	 * bpf_mem_cache_free() for all remaining elements and they can be in
	 * free_by_rcu_ttrace or in waiting_for_gp_ttrace lists, so drain those lists now.
	 *
	 * Except for waiting_for_gp_ttrace list, there are no concurrent operations
	 * on these lists, so it is safe to use __llist_del_all().
	 */
	free_all(llist_del_all(&c->free_by_rcu_ttrace), percpu);
	free_all(llist_del_all(&c->waiting_for_gp_ttrace), percpu);
	free_all(__llist_del_all(&c->free_llist), percpu);
	free_all(__llist_del_all(&c->free_llist_extra), percpu);
	free_all(__llist_del_all(&c->free_by_rcu), percpu);
	free_all(__llist_del_all(&c->free_llist_extra_rcu), percpu);
	free_all(llist_del_all(&c->waiting_for_gp), percpu);
}

static void check_mem_cache(struct bpf_mem_cache *c)
{
	WARN_ON_ONCE(!llist_empty(&c->free_by_rcu_ttrace));
	WARN_ON_ONCE(!llist_empty(&c->waiting_for_gp_ttrace));
	WARN_ON_ONCE(!llist_empty(&c->free_llist));
	WARN_ON_ONCE(!llist_empty(&c->free_llist_extra));
	WARN_ON_ONCE(!llist_empty(&c->free_by_rcu));
	WARN_ON_ONCE(!llist_empty(&c->free_llist_extra_rcu));
	WARN_ON_ONCE(!llist_empty(&c->waiting_for_gp));
}

static void check_leaked_objs(struct bpf_mem_alloc *ma)
{
	struct bpf_mem_caches *cc;
	struct bpf_mem_cache *c;
	int cpu, i;

	if (ma->cache) {
		for_each_possible_cpu(cpu) {
			c = per_cpu_ptr(ma->cache, cpu);
			check_mem_cache(c);
		}
	}
	if (ma->caches) {
		for_each_possible_cpu(cpu) {
			cc = per_cpu_ptr(ma->caches, cpu);
			for (i = 0; i < NUM_CACHES; i++) {
				c = &cc->cache[i];
				check_mem_cache(c);
			}
		}
	}
}

static void free_mem_alloc_no_barrier(struct bpf_mem_alloc *ma)
{
	check_leaked_objs(ma);
	free_percpu(ma->cache);
	free_percpu(ma->caches);
	ma->cache = NULL;
	ma->caches = NULL;
}

static void free_mem_alloc(struct bpf_mem_alloc *ma)
{
	/* waiting_for_gp[_ttrace] lists were drained, but RCU callbacks
	 * might still execute. Wait for them.
	 *
	 * rcu_barrier_tasks_trace() doesn't imply synchronize_rcu_tasks_trace(),
	 * but rcu_barrier_tasks_trace() and rcu_barrier() below are only used
	 * to wait for the pending __free_rcu_tasks_trace() and __free_rcu(),
	 * so if call_rcu(head, __free_rcu) is skipped due to
	 * rcu_trace_implies_rcu_gp(), it will be OK to skip rcu_barrier() by
	 * using rcu_trace_implies_rcu_gp() as well.
	 */
	rcu_barrier(); /* wait for __free_by_rcu */
	rcu_barrier_tasks_trace(); /* wait for __free_rcu */
	if (!rcu_trace_implies_rcu_gp())
		rcu_barrier();
	free_mem_alloc_no_barrier(ma);
}

static void free_mem_alloc_deferred(struct work_struct *work)
{
	struct bpf_mem_alloc *ma = container_of(work, struct bpf_mem_alloc, work);

	free_mem_alloc(ma);
	kfree(ma);
}

static void destroy_mem_alloc(struct bpf_mem_alloc *ma, int rcu_in_progress)
{
	struct bpf_mem_alloc *copy;

	if (!rcu_in_progress) {
		/* Fast path. No callbacks are pending, hence no need to do
		 * rcu_barrier-s.
		 */
		free_mem_alloc_no_barrier(ma);
		return;
	}

	copy = kmemdup(ma, sizeof(*ma), GFP_KERNEL);
	if (!copy) {
		/* Slow path with inline barrier-s */
		free_mem_alloc(ma);
		return;
	}

	/* Defer the barriers into the worker to let the rest of the map memory
	 * be freed.
	 */
	memset(ma, 0, sizeof(*ma));
	INIT_WORK(&copy->work, free_mem_alloc_deferred);
	queue_work(system_unbound_wq, &copy->work);
}

void bpf_mem_alloc_destroy(struct bpf_mem_alloc *ma)
{
	struct bpf_mem_caches *cc;
	struct bpf_mem_cache *c;
	int cpu, i, rcu_in_progress;

	if (ma->cache) {
		rcu_in_progress = 0;
		for_each_possible_cpu(cpu) {
			c = per_cpu_ptr(ma->cache, cpu);
			WRITE_ONCE(c->draining, true);
			irq_work_sync(&c->refill_work);
			drain_mem_cache(c);
			rcu_in_progress += atomic_read(&c->call_rcu_ttrace_in_progress);
			rcu_in_progress += atomic_read(&c->call_rcu_in_progress);
		}
		obj_cgroup_put(ma->objcg);
		destroy_mem_alloc(ma, rcu_in_progress);
	}
	if (ma->caches) {
		rcu_in_progress = 0;
		for_each_possible_cpu(cpu) {
			cc = per_cpu_ptr(ma->caches, cpu);
			for (i = 0; i < NUM_CACHES; i++) {
				c = &cc->cache[i];
				WRITE_ONCE(c->draining, true);
				irq_work_sync(&c->refill_work);
				drain_mem_cache(c);
				rcu_in_progress += atomic_read(&c->call_rcu_ttrace_in_progress);
				rcu_in_progress += atomic_read(&c->call_rcu_in_progress);
			}
		}
		obj_cgroup_put(ma->objcg);
		destroy_mem_alloc(ma, rcu_in_progress);
	}
}

/* notrace is necessary here and in other functions to make sure
 * bpf programs cannot attach to them and cause llist corruption.
 */
static void notrace *unit_alloc(struct bpf_mem_cache *c)
{
	struct llist_node *llnode = NULL;
	unsigned long flags;
	int cnt = 0;

	/* Disable irqs to prevent the following race for the majority of prog types:
	 * prog_A
	 *   bpf_mem_alloc
	 *      preemption or irq -> prog_B
	 *        bpf_mem_alloc
	 *
	 * but prog_B could be a perf_event NMI prog.
	 * Use the per-cpu 'active' counter to order free_llist access between
	 * unit_alloc/unit_free/bpf_mem_refill.
	 */
	local_irq_save(flags);
	if (local_inc_return(&c->active) == 1) {
		llnode = __llist_del_first(&c->free_llist);
		if (llnode) {
			cnt = --c->free_cnt;
			*(struct bpf_mem_cache **)llnode = c;
		}
	}
	local_dec(&c->active);

	WARN_ON(cnt < 0);

	if (cnt < c->low_watermark)
		irq_work_raise(c);
	/* Enable IRQs after the enqueue of the irq work completes, so the irq
	 * work will run after IRQs are enabled and free_llist may be refilled
	 * by the irq work before another task preempts the current task.
	 */
	local_irq_restore(flags);

	return llnode;
}

/* Though the 'ptr' object could have been allocated on a different cpu,
 * add it to the free_llist of the current cpu.
 * Let the kfree() logic deal with it when it's later called from irq_work.
 */
static void notrace unit_free(struct bpf_mem_cache *c, void *ptr)
{
	struct llist_node *llnode = ptr - LLIST_NODE_SZ;
	unsigned long flags;
	int cnt = 0;

	BUILD_BUG_ON(LLIST_NODE_SZ > 8);

	/*
	 * Remember the bpf_mem_cache that allocated this object.
	 * The hint is not accurate.
	 */
	c->tgt = *(struct bpf_mem_cache **)llnode;

	local_irq_save(flags);
	if (local_inc_return(&c->active) == 1) {
		__llist_add(llnode, &c->free_llist);
		cnt = ++c->free_cnt;
	} else {
		/* unit_free() cannot fail. Therefore add the object to the
		 * atomic llist. free_bulk() will drain it. Though
		 * free_llist_extra is a per-cpu list, we have to use the
		 * atomic llist_add here, since it also can be interrupted by
		 * a bpf nmi prog that does another unit_free() into the same
		 * free_llist_extra.
		 */
		llist_add(llnode, &c->free_llist_extra);
	}
	local_dec(&c->active);

	if (cnt > c->high_watermark)
		/* free a few objects from the current cpu into the global kmalloc pool */
		irq_work_raise(c);
	/* Enable IRQs after irq_work_raise() completes, otherwise when the
	 * current task is preempted by a task which does unit_alloc(),
	 * unit_alloc() may return NULL unexpectedly because the irq work is
	 * already pending but has not been triggered yet and free_llist
	 * cannot be refilled in time.
	 */
	local_irq_restore(flags);
}

static void notrace unit_free_rcu(struct bpf_mem_cache *c, void *ptr)
{
	struct llist_node *llnode = ptr - LLIST_NODE_SZ;
	unsigned long flags;

	c->tgt = *(struct bpf_mem_cache **)llnode;

	local_irq_save(flags);
	if (local_inc_return(&c->active) == 1) {
		if (__llist_add(llnode, &c->free_by_rcu))
			c->free_by_rcu_tail = llnode;
	} else {
		llist_add(llnode, &c->free_llist_extra_rcu);
	}
	local_dec(&c->active);

	if (!atomic_read(&c->call_rcu_in_progress))
		irq_work_raise(c);
	local_irq_restore(flags);
}

/* Called from a BPF program or from the sys_bpf syscall.
 * In both cases migration is disabled.
 */
void notrace *bpf_mem_alloc(struct bpf_mem_alloc *ma, size_t size)
{
	int idx;
	void *ret;

	if (!size)
		return NULL;

	if (!ma->percpu)
		size += LLIST_NODE_SZ;
	idx = bpf_mem_cache_idx(size);
	if (idx < 0)
		return NULL;

	ret = unit_alloc(this_cpu_ptr(ma->caches)->cache + idx);
	return !ret ? NULL : ret + LLIST_NODE_SZ;
}
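
/* Layout note (a restatement of the code above and below): the pointer
 * returned to the caller is LLIST_NODE_SZ past the real allocation. The
 * hidden header holds the llist_node while the object is cached and a
 * 'struct bpf_mem_cache *' hint while the object is in use, which is how
 * bpf_mem_free() finds the right sized bucket via c->unit_size.
 */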

void notrace bpf_mem_free(struct bpf_mem_alloc *ma, void *ptr)
{
	struct bpf_mem_cache *c;
	int idx;

	if (!ptr)
		return;

	c = *(void **)(ptr - LLIST_NODE_SZ);
	idx = bpf_mem_cache_idx(c->unit_size);
	if (WARN_ON_ONCE(idx < 0))
		return;

	unit_free(this_cpu_ptr(ma->caches)->cache + idx, ptr);
}

void notrace bpf_mem_free_rcu(struct bpf_mem_alloc *ma, void *ptr)
{
	struct bpf_mem_cache *c;
	int idx;

	if (!ptr)
		return;

	c = *(void **)(ptr - LLIST_NODE_SZ);
	idx = bpf_mem_cache_idx(c->unit_size);
	if (WARN_ON_ONCE(idx < 0))
		return;

	unit_free_rcu(this_cpu_ptr(ma->caches)->cache + idx, ptr);
}

void notrace *bpf_mem_cache_alloc(struct bpf_mem_alloc *ma)
{
	void *ret;

	ret = unit_alloc(this_cpu_ptr(ma->cache));
	return !ret ? NULL : ret + LLIST_NODE_SZ;
}

void notrace bpf_mem_cache_free(struct bpf_mem_alloc *ma, void *ptr)
{
	if (!ptr)
		return;

	unit_free(this_cpu_ptr(ma->cache), ptr);
}

void notrace bpf_mem_cache_free_rcu(struct bpf_mem_alloc *ma, void *ptr)
{
	if (!ptr)
		return;

	unit_free_rcu(this_cpu_ptr(ma->cache), ptr);
}

/* Directly does a kfree() without putting 'ptr' back to the free_llist
 * for reuse and without waiting for a rcu_tasks_trace gp.
 * The caller must first go through the rcu_tasks_trace gp for 'ptr'
 * before calling bpf_mem_cache_raw_free().
 * It could be used when the rcu_tasks_trace callback does not have
 * a hold on the original bpf_mem_alloc object that allocated the
 * 'ptr'. This should only be used in the uncommon code path.
 * Otherwise, the bpf_mem_alloc's free_llist cannot be refilled
 * and may affect performance.
 */
void bpf_mem_cache_raw_free(void *ptr)
{
	if (!ptr)
		return;

	kfree(ptr - LLIST_NODE_SZ);
}

/* When flags == GFP_KERNEL, it signals that the caller will not cause
 * deadlock when using kmalloc. bpf_mem_cache_alloc_flags() will use
 * kmalloc if the free_llist is empty.
 */
void notrace *bpf_mem_cache_alloc_flags(struct bpf_mem_alloc *ma, gfp_t flags)
{
	struct bpf_mem_cache *c;
	void *ret;

	c = this_cpu_ptr(ma->cache);

	ret = unit_alloc(c);
	if (!ret && flags == GFP_KERNEL) {
		struct mem_cgroup *memcg, *old_memcg;

		memcg = get_memcg(c);
		old_memcg = set_active_memcg(memcg);
		ret = __alloc(c, NUMA_NO_NODE, GFP_KERNEL | __GFP_NOWARN | __GFP_ACCOUNT);
		if (ret)
			*(struct bpf_mem_cache **)ret = c;
		set_active_memcg(old_memcg);
		mem_cgroup_put(memcg);
	}

	return !ret ? NULL : ret + LLIST_NODE_SZ;
}

int bpf_mem_alloc_check_size(bool percpu, size_t size)
{
	/* The size of percpu allocation doesn't have LLIST_NODE_SZ overhead */
	if ((percpu && size > BPF_MEM_ALLOC_SIZE_MAX) ||
	    (!percpu && size > BPF_MEM_ALLOC_SIZE_MAX - LLIST_NODE_SZ))
		return -E2BIG;

	return 0;
}