1  // SPDX-License-Identifier: GPL-2.0
2  /* Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES.
3   *
4   * The iopt_pages is the center of the storage and motion of PFNs. Each
5   * iopt_pages represents a logical linear array of full PFNs. The array is 0
6   * based and has npages in it. Accessors use 'index' to refer to the entry in
7   * this logical array, regardless of its storage location.
8   *
9   * PFNs are stored in a tiered scheme:
10   *  1) iopt_pages::pinned_pfns xarray
11   *  2) An iommu_domain
12   *  3) The origin of the PFNs, i.e. the userspace pointer
13   *
 * PFNs have to be copied between all combinations of tiers, depending on the
 * configuration.
16   *
17   * When a PFN is taken out of the userspace pointer it is pinned exactly once.
18   * The storage locations of the PFN's index are tracked in the two interval
19   * trees. If no interval includes the index then it is not pinned.
20   *
21   * If access_itree includes the PFN's index then an in-kernel access has
22   * requested the page. The PFN is stored in the xarray so other requestors can
23   * continue to find it.
24   *
25   * If the domains_itree includes the PFN's index then an iommu_domain is storing
26   * the PFN and it can be read back using iommu_iova_to_phys(). To avoid
27   * duplicating storage the xarray is not used if only iommu_domains are using
28   * the PFN's index.
29   *
30   * As a general principle this is designed so that destroy never fails. This
 * means removing an iommu_domain or releasing an in-kernel access will not fail
32   * due to insufficient memory. In practice this means some cases have to hold
33   * PFNs in the xarray even though they are also being stored in an iommu_domain.
34   *
35   * While the iopt_pages can use an iommu_domain as storage, it does not have an
36   * IOVA itself. Instead the iopt_area represents a range of IOVA and uses the
37   * iopt_pages as the PFN provider. Multiple iopt_areas can share the iopt_pages
38   * and reference their own slice of the PFN array, with sub page granularity.
39   *
40   * In this file the term 'last' indicates an inclusive and closed interval, eg
41   * [0,0] refers to a single PFN. 'end' means an open range, eg [0,0) refers to
42   * no PFNs.
43   *
44   * Be cautious of overflow. An IOVA can go all the way up to U64_MAX, so
45   * last_iova + 1 can overflow. An iopt_pages index will always be much less than
46   * ULONG_MAX so last_index + 1 cannot overflow.
47   */
48  #include <linux/highmem.h>
49  #include <linux/iommu.h>
50  #include <linux/iommufd.h>
51  #include <linux/kthread.h>
52  #include <linux/overflow.h>
53  #include <linux/slab.h>
54  #include <linux/sched/mm.h>
55  
56  #include "double_span.h"
57  #include "io_pagetable.h"
58  
/*
 * Upper bound, in bytes, on temporary allocations made by temp_kmalloc().
 * Test builds substitute iommufd_test_memory_limit instead of the fixed 64k
 * (presumably so the selftests can shrink it - confirm where it is defined).
 */
#ifdef CONFIG_IOMMUFD_TEST
#define TEMP_MEMORY_LIMIT iommufd_test_memory_limit
#else
#define TEMP_MEMORY_LIMIT 65536
#endif
/*
 * NOTE(review): not referenced in the visible part of this file; presumably
 * the size of the small backup buffers handed to batch_init_backup().
 */
#define BATCH_BACKUP_SIZE 32
65  
66  /*
67   * More memory makes pin_user_pages() and the batching more efficient, but as
68   * this is only a performance optimization don't try too hard to get it. A 64k
69   * allocation can hold about 26M of 4k pages and 13G of 2M pages in an
70   * pfn_batch. Various destroy paths cannot fail and provide a small amount of
71   * stack memory as a backup contingency. If backup_len is given this cannot
72   * fail.
73   */
temp_kmalloc(size_t * size,void * backup,size_t backup_len)74  static void *temp_kmalloc(size_t *size, void *backup, size_t backup_len)
75  {
76  	void *res;
77  
78  	if (WARN_ON(*size == 0))
79  		return NULL;
80  
81  	if (*size < backup_len)
82  		return backup;
83  
84  	if (!backup && iommufd_should_fail())
85  		return NULL;
86  
87  	*size = min_t(size_t, *size, TEMP_MEMORY_LIMIT);
88  	res = kmalloc(*size, GFP_KERNEL | __GFP_NOWARN | __GFP_NORETRY);
89  	if (res)
90  		return res;
91  	*size = PAGE_SIZE;
92  	if (backup_len) {
93  		res = kmalloc(*size, GFP_KERNEL | __GFP_NOWARN | __GFP_NORETRY);
94  		if (res)
95  			return res;
96  		*size = backup_len;
97  		return backup;
98  	}
99  	return kmalloc(*size, GFP_KERNEL);
100  }
101  
interval_tree_double_span_iter_update(struct interval_tree_double_span_iter * iter)102  void interval_tree_double_span_iter_update(
103  	struct interval_tree_double_span_iter *iter)
104  {
105  	unsigned long last_hole = ULONG_MAX;
106  	unsigned int i;
107  
108  	for (i = 0; i != ARRAY_SIZE(iter->spans); i++) {
109  		if (interval_tree_span_iter_done(&iter->spans[i])) {
110  			iter->is_used = -1;
111  			return;
112  		}
113  
114  		if (iter->spans[i].is_hole) {
115  			last_hole = min(last_hole, iter->spans[i].last_hole);
116  			continue;
117  		}
118  
119  		iter->is_used = i + 1;
120  		iter->start_used = iter->spans[i].start_used;
121  		iter->last_used = min(iter->spans[i].last_used, last_hole);
122  		return;
123  	}
124  
125  	iter->is_used = 0;
126  	iter->start_hole = iter->spans[0].start_hole;
127  	iter->last_hole =
128  		min(iter->spans[0].last_hole, iter->spans[1].last_hole);
129  }
130  
interval_tree_double_span_iter_first(struct interval_tree_double_span_iter * iter,struct rb_root_cached * itree1,struct rb_root_cached * itree2,unsigned long first_index,unsigned long last_index)131  void interval_tree_double_span_iter_first(
132  	struct interval_tree_double_span_iter *iter,
133  	struct rb_root_cached *itree1, struct rb_root_cached *itree2,
134  	unsigned long first_index, unsigned long last_index)
135  {
136  	unsigned int i;
137  
138  	iter->itrees[0] = itree1;
139  	iter->itrees[1] = itree2;
140  	for (i = 0; i != ARRAY_SIZE(iter->spans); i++)
141  		interval_tree_span_iter_first(&iter->spans[i], iter->itrees[i],
142  					      first_index, last_index);
143  	interval_tree_double_span_iter_update(iter);
144  }
145  
interval_tree_double_span_iter_next(struct interval_tree_double_span_iter * iter)146  void interval_tree_double_span_iter_next(
147  	struct interval_tree_double_span_iter *iter)
148  {
149  	unsigned int i;
150  
151  	if (iter->is_used == -1 ||
152  	    iter->last_hole == iter->spans[0].last_index) {
153  		iter->is_used = -1;
154  		return;
155  	}
156  
157  	for (i = 0; i != ARRAY_SIZE(iter->spans); i++)
158  		interval_tree_span_iter_advance(
159  			&iter->spans[i], iter->itrees[i], iter->last_hole + 1);
160  	interval_tree_double_span_iter_update(iter);
161  }
162  
/*
 * Add @npages to pages->npinned. The counter is always updated. Only test
 * builds WARN on wrap-around or on npinned exceeding npages.
 */
static void iopt_pages_add_npinned(struct iopt_pages *pages, size_t npages)
{
	int overflowed = check_add_overflow(pages->npinned, npages,
					    &pages->npinned);

	if (!IS_ENABLED(CONFIG_IOMMUFD_TEST))
		return;
	WARN_ON(overflowed || pages->npinned > pages->npages);
}
171  
/*
 * Subtract @npages from pages->npinned. The counter is always updated. Only
 * test builds WARN on underflow or on npinned exceeding npages.
 */
static void iopt_pages_sub_npinned(struct iopt_pages *pages, size_t npages)
{
	int underflowed = check_sub_overflow(pages->npinned, npages,
					     &pages->npinned);

	if (!IS_ENABLED(CONFIG_IOMMUFD_TEST))
		return;
	WARN_ON(underflowed || pages->npinned > pages->npages);
}
180  
/*
 * Unpin the pages in @page_list that back indexes [start_index, last_index]
 * and remove them from pages->npinned.
 */
static void iopt_pages_err_unpin(struct iopt_pages *pages,
				 unsigned long start_index,
				 unsigned long last_index,
				 struct page **page_list)
{
	const unsigned long count = last_index - start_index + 1;

	unpin_user_pages(page_list, count);
	iopt_pages_sub_npinned(pages, count);
}
191  
192  /*
193   * index is the number of PAGE_SIZE units from the start of the area's
194   * iopt_pages. If the iova is sub page-size then the area has an iova that
195   * covers a portion of the first and last pages in the range.
196   */
iopt_area_index_to_iova(struct iopt_area * area,unsigned long index)197  static unsigned long iopt_area_index_to_iova(struct iopt_area *area,
198  					     unsigned long index)
199  {
200  	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
201  		WARN_ON(index < iopt_area_index(area) ||
202  			index > iopt_area_last_index(area));
203  	index -= iopt_area_index(area);
204  	if (index == 0)
205  		return iopt_area_iova(area);
206  	return iopt_area_iova(area) - area->page_offset + index * PAGE_SIZE;
207  }
208  
iopt_area_index_to_iova_last(struct iopt_area * area,unsigned long index)209  static unsigned long iopt_area_index_to_iova_last(struct iopt_area *area,
210  						  unsigned long index)
211  {
212  	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
213  		WARN_ON(index < iopt_area_index(area) ||
214  			index > iopt_area_last_index(area));
215  	if (index == iopt_area_last_index(area))
216  		return iopt_area_last_iova(area);
217  	return iopt_area_iova(area) - area->page_offset +
218  	       (index - iopt_area_index(area) + 1) * PAGE_SIZE - 1;
219  }
220  
/*
 * Unmap [iova, iova + size) from @domain on a path that has no way to report
 * failure. A short unmap only produces a WARN.
 */
static void iommu_unmap_nofail(struct iommu_domain *domain, unsigned long iova,
			       size_t size)
{
	/*
	 * The driver must unmap exactly what was requested. Anything else is a
	 * logic error here or a driver bug. An iommu driver may therefore fail
	 * unmap only for bad arguments. In particular, it must not allocate
	 * memory on the unmap path.
	 */
	WARN_ON(iommu_unmap(domain, iova, size) != size);
}
236  
/*
 * Unmap from @domain the IOVA range of @area that backs page indexes
 * [start_index, last_index]. Partial first and last pages are handled by
 * the index-to-IOVA helpers.
 */
static void iopt_area_unmap_domain_range(struct iopt_area *area,
					 struct iommu_domain *domain,
					 unsigned long start_index,
					 unsigned long last_index)
{
	unsigned long first_iova = iopt_area_index_to_iova(area, start_index);
	unsigned long final_iova =
		iopt_area_index_to_iova_last(area, last_index);

	iommu_unmap_nofail(domain, first_iova, final_iova - first_iova + 1);
}
248  
iopt_pages_find_domain_area(struct iopt_pages * pages,unsigned long index)249  static struct iopt_area *iopt_pages_find_domain_area(struct iopt_pages *pages,
250  						     unsigned long index)
251  {
252  	struct interval_tree_node *node;
253  
254  	node = interval_tree_iter_first(&pages->domains_itree, index, index);
255  	if (!node)
256  		return NULL;
257  	return container_of(node, struct iopt_area, pages_node);
258  }
259  
/*
 * A simple datastructure to hold a vector of PFNs, optimized for contiguous
 * PFNs. This is used as a temporary holding memory for shuttling pfns from one
 * place to another. Generally everything is made more efficient if operations
 * work on the largest possible grouping of pfns. eg fewer lock/unlock cycles,
 * better cache locality, etc
 *
 * PFNs are stored as runs: run i covers pfns[i] .. pfns[i] + npfns[i] - 1.
 * __batch_init() allocates both arrays from one buffer, with npfns placed
 * directly after the array_size entries of pfns.
 */
struct pfn_batch {
	/* First PFN of each run */
	unsigned long *pfns;
	/* Number of PFNs in each run; batch_add_pfn() caps a run at U32_MAX */
	u32 *npfns;
	/* Capacity of pfns/npfns, in runs */
	unsigned int array_size;
	/* Number of runs in use: pfns[0..end) */
	unsigned int end;
	/* Total PFNs across all runs in use */
	unsigned int total_pfns;
};
274  
batch_clear(struct pfn_batch * batch)275  static void batch_clear(struct pfn_batch *batch)
276  {
277  	batch->total_pfns = 0;
278  	batch->end = 0;
279  	batch->pfns[0] = 0;
280  	batch->npfns[0] = 0;
281  }
282  
283  /*
284   * Carry means we carry a portion of the final hugepage over to the front of the
285   * batch
286   */
batch_clear_carry(struct pfn_batch * batch,unsigned int keep_pfns)287  static void batch_clear_carry(struct pfn_batch *batch, unsigned int keep_pfns)
288  {
289  	if (!keep_pfns)
290  		return batch_clear(batch);
291  
292  	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
293  		WARN_ON(!batch->end ||
294  			batch->npfns[batch->end - 1] < keep_pfns);
295  
296  	batch->total_pfns = keep_pfns;
297  	batch->pfns[0] = batch->pfns[batch->end - 1] +
298  			 (batch->npfns[batch->end - 1] - keep_pfns);
299  	batch->npfns[0] = keep_pfns;
300  	batch->end = 1;
301  }
302  
batch_skip_carry(struct pfn_batch * batch,unsigned int skip_pfns)303  static void batch_skip_carry(struct pfn_batch *batch, unsigned int skip_pfns)
304  {
305  	if (!batch->total_pfns)
306  		return;
307  	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
308  		WARN_ON(batch->total_pfns != batch->npfns[0]);
309  	skip_pfns = min(batch->total_pfns, skip_pfns);
310  	batch->pfns[0] += skip_pfns;
311  	batch->npfns[0] -= skip_pfns;
312  	batch->total_pfns -= skip_pfns;
313  }
314  
/*
 * __batch_init() - Allocate storage for a pfn_batch.
 * @batch: batch to initialize
 * @max_pages: number of runs wanted; fewer may be granted (see temp_kmalloc())
 * @backup: optional caller-owned fallback buffer, may be NULL
 * @backup_len: size of @backup in bytes, 0 when there is none
 *
 * pfns and npfns share one buffer: array_size unsigned longs followed by
 * array_size u32s.
 *
 * Return: 0 on success. -ENOMEM if no buffer could be obtained. In test
 * builds only, -EINVAL if the buffer cannot hold a single entry. On that
 * path a kmalloc'd buffer is freed and batch->pfns is cleared so nothing
 * leaks. A caller-owned @backup is never freed.
 */
static int __batch_init(struct pfn_batch *batch, size_t max_pages, void *backup,
			size_t backup_len)
{
	const size_t elmsz = sizeof(*batch->pfns) + sizeof(*batch->npfns);
	size_t size = max_pages * elmsz;

	batch->pfns = temp_kmalloc(&size, backup, backup_len);
	if (!batch->pfns)
		return -ENOMEM;
	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) && WARN_ON(size < elmsz)) {
		/*
		 * The batch is unusable, so release the buffer here. Clearing
		 * the pointer keeps a later batch_destroy() harmless.
		 */
		if (batch->pfns != backup)
			kfree(batch->pfns);
		batch->pfns = NULL;
		return -EINVAL;
	}
	batch->array_size = size / elmsz;
	/* npfns lives directly after the array_size entries of pfns */
	batch->npfns = (u32 *)(batch->pfns + batch->array_size);
	batch_clear(batch);
	return 0;
}
331  
/*
 * Allocate @batch for up to @max_pages runs with no fallback buffer.
 * Returns whatever __batch_init() returns (0 or a negative errno).
 */
static int batch_init(struct pfn_batch *batch, size_t max_pages)
{
	int rc = __batch_init(batch, max_pages, NULL, 0);

	return rc;
}
336  
/*
 * Like batch_init(), but with a caller-supplied fallback buffer for paths
 * that cannot fail. The return value is discarded. With a backup buffer,
 * temp_kmalloc() does not return NULL for a non-zero size.
 */
static void batch_init_backup(struct pfn_batch *batch, size_t max_pages,
			      void *backup, size_t backup_len)
{
	(void)__batch_init(batch, max_pages, backup, backup_len);
}
342  
batch_destroy(struct pfn_batch * batch,void * backup)343  static void batch_destroy(struct pfn_batch *batch, void *backup)
344  {
345  	if (batch->pfns != backup)
346  		kfree(batch->pfns);
347  }
348  
349  /* true if the pfn was added, false otherwise */
batch_add_pfn(struct pfn_batch * batch,unsigned long pfn)350  static bool batch_add_pfn(struct pfn_batch *batch, unsigned long pfn)
351  {
352  	const unsigned int MAX_NPFNS = type_max(typeof(*batch->npfns));
353  
354  	if (batch->end &&
355  	    pfn == batch->pfns[batch->end - 1] + batch->npfns[batch->end - 1] &&
356  	    batch->npfns[batch->end - 1] != MAX_NPFNS) {
357  		batch->npfns[batch->end - 1]++;
358  		batch->total_pfns++;
359  		return true;
360  	}
361  	if (batch->end == batch->array_size)
362  		return false;
363  	batch->total_pfns++;
364  	batch->pfns[batch->end] = pfn;
365  	batch->npfns[batch->end] = 1;
366  	batch->end++;
367  	return true;
368  }
369  
370  /*
371   * Fill the batch with pfns from the domain. When the batch is full, or it
372   * reaches last_index, the function will return. The caller should use
373   * batch->total_pfns to determine the starting point for the next iteration.
374   */
batch_from_domain(struct pfn_batch * batch,struct iommu_domain * domain,struct iopt_area * area,unsigned long start_index,unsigned long last_index)375  static void batch_from_domain(struct pfn_batch *batch,
376  			      struct iommu_domain *domain,
377  			      struct iopt_area *area, unsigned long start_index,
378  			      unsigned long last_index)
379  {
380  	unsigned int page_offset = 0;
381  	unsigned long iova;
382  	phys_addr_t phys;
383  
384  	iova = iopt_area_index_to_iova(area, start_index);
385  	if (start_index == iopt_area_index(area))
386  		page_offset = area->page_offset;
387  	while (start_index <= last_index) {
388  		/*
389  		 * This is pretty slow, it would be nice to get the page size
390  		 * back from the driver, or have the driver directly fill the
391  		 * batch.
392  		 */
393  		phys = iommu_iova_to_phys(domain, iova) - page_offset;
394  		if (!batch_add_pfn(batch, PHYS_PFN(phys)))
395  			return;
396  		iova += PAGE_SIZE - page_offset;
397  		page_offset = 0;
398  		start_index++;
399  	}
400  }
401  
raw_pages_from_domain(struct iommu_domain * domain,struct iopt_area * area,unsigned long start_index,unsigned long last_index,struct page ** out_pages)402  static struct page **raw_pages_from_domain(struct iommu_domain *domain,
403  					   struct iopt_area *area,
404  					   unsigned long start_index,
405  					   unsigned long last_index,
406  					   struct page **out_pages)
407  {
408  	unsigned int page_offset = 0;
409  	unsigned long iova;
410  	phys_addr_t phys;
411  
412  	iova = iopt_area_index_to_iova(area, start_index);
413  	if (start_index == iopt_area_index(area))
414  		page_offset = area->page_offset;
415  	while (start_index <= last_index) {
416  		phys = iommu_iova_to_phys(domain, iova) - page_offset;
417  		*(out_pages++) = pfn_to_page(PHYS_PFN(phys));
418  		iova += PAGE_SIZE - page_offset;
419  		page_offset = 0;
420  		start_index++;
421  	}
422  	return out_pages;
423  }
424  
425  /* Continues reading a domain until we reach a discontinuity in the pfns. */
batch_from_domain_continue(struct pfn_batch * batch,struct iommu_domain * domain,struct iopt_area * area,unsigned long start_index,unsigned long last_index)426  static void batch_from_domain_continue(struct pfn_batch *batch,
427  				       struct iommu_domain *domain,
428  				       struct iopt_area *area,
429  				       unsigned long start_index,
430  				       unsigned long last_index)
431  {
432  	unsigned int array_size = batch->array_size;
433  
434  	batch->array_size = batch->end;
435  	batch_from_domain(batch, domain, area, start_index, last_index);
436  	batch->array_size = array_size;
437  }
438  
439  /*
440   * This is part of the VFIO compatibility support for VFIO_TYPE1_IOMMU. That
441   * mode permits splitting a mapped area up, and then one of the splits is
442   * unmapped. Doing this normally would cause us to violate our invariant of
443   * pairing map/unmap. Thus, to support old VFIO compatibility disable support
444   * for batching consecutive PFNs. All PFNs mapped into the iommu are done in
445   * PAGE_SIZE units, not larger or smaller.
446   */
batch_iommu_map_small(struct iommu_domain * domain,unsigned long iova,phys_addr_t paddr,size_t size,int prot)447  static int batch_iommu_map_small(struct iommu_domain *domain,
448  				 unsigned long iova, phys_addr_t paddr,
449  				 size_t size, int prot)
450  {
451  	unsigned long start_iova = iova;
452  	int rc;
453  
454  	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
455  		WARN_ON(paddr % PAGE_SIZE || iova % PAGE_SIZE ||
456  			size % PAGE_SIZE);
457  
458  	while (size) {
459  		rc = iommu_map(domain, iova, paddr, PAGE_SIZE, prot,
460  			       GFP_KERNEL_ACCOUNT);
461  		if (rc)
462  			goto err_unmap;
463  		iova += PAGE_SIZE;
464  		paddr += PAGE_SIZE;
465  		size -= PAGE_SIZE;
466  	}
467  	return 0;
468  
469  err_unmap:
470  	if (start_iova != iova)
471  		iommu_unmap_nofail(domain, start_iova, iova - start_iova);
472  	return rc;
473  }
474  
batch_to_domain(struct pfn_batch * batch,struct iommu_domain * domain,struct iopt_area * area,unsigned long start_index)475  static int batch_to_domain(struct pfn_batch *batch, struct iommu_domain *domain,
476  			   struct iopt_area *area, unsigned long start_index)
477  {
478  	bool disable_large_pages = area->iopt->disable_large_pages;
479  	unsigned long last_iova = iopt_area_last_iova(area);
480  	unsigned int page_offset = 0;
481  	unsigned long start_iova;
482  	unsigned long next_iova;
483  	unsigned int cur = 0;
484  	unsigned long iova;
485  	int rc;
486  
487  	/* The first index might be a partial page */
488  	if (start_index == iopt_area_index(area))
489  		page_offset = area->page_offset;
490  	next_iova = iova = start_iova =
491  		iopt_area_index_to_iova(area, start_index);
492  	while (cur < batch->end) {
493  		next_iova = min(last_iova + 1,
494  				next_iova + batch->npfns[cur] * PAGE_SIZE -
495  					page_offset);
496  		if (disable_large_pages)
497  			rc = batch_iommu_map_small(
498  				domain, iova,
499  				PFN_PHYS(batch->pfns[cur]) + page_offset,
500  				next_iova - iova, area->iommu_prot);
501  		else
502  			rc = iommu_map(domain, iova,
503  				       PFN_PHYS(batch->pfns[cur]) + page_offset,
504  				       next_iova - iova, area->iommu_prot,
505  				       GFP_KERNEL_ACCOUNT);
506  		if (rc)
507  			goto err_unmap;
508  		iova = next_iova;
509  		page_offset = 0;
510  		cur++;
511  	}
512  	return 0;
513  err_unmap:
514  	if (start_iova != iova)
515  		iommu_unmap_nofail(domain, start_iova, iova - start_iova);
516  	return rc;
517  }
518  
batch_from_xarray(struct pfn_batch * batch,struct xarray * xa,unsigned long start_index,unsigned long last_index)519  static void batch_from_xarray(struct pfn_batch *batch, struct xarray *xa,
520  			      unsigned long start_index,
521  			      unsigned long last_index)
522  {
523  	XA_STATE(xas, xa, start_index);
524  	void *entry;
525  
526  	rcu_read_lock();
527  	while (true) {
528  		entry = xas_next(&xas);
529  		if (xas_retry(&xas, entry))
530  			continue;
531  		WARN_ON(!xa_is_value(entry));
532  		if (!batch_add_pfn(batch, xa_to_value(entry)) ||
533  		    start_index == last_index)
534  			break;
535  		start_index++;
536  	}
537  	rcu_read_unlock();
538  }
539  
batch_from_xarray_clear(struct pfn_batch * batch,struct xarray * xa,unsigned long start_index,unsigned long last_index)540  static void batch_from_xarray_clear(struct pfn_batch *batch, struct xarray *xa,
541  				    unsigned long start_index,
542  				    unsigned long last_index)
543  {
544  	XA_STATE(xas, xa, start_index);
545  	void *entry;
546  
547  	xas_lock(&xas);
548  	while (true) {
549  		entry = xas_next(&xas);
550  		if (xas_retry(&xas, entry))
551  			continue;
552  		WARN_ON(!xa_is_value(entry));
553  		if (!batch_add_pfn(batch, xa_to_value(entry)))
554  			break;
555  		xas_store(&xas, NULL);
556  		if (start_index == last_index)
557  			break;
558  		start_index++;
559  	}
560  	xas_unlock(&xas);
561  }
562  
/*
 * Erase every present entry of @xa in [start_index, last_index]. The walk
 * runs under the xa_lock and visits only present entries (xas_for_each()).
 */
static void clear_xarray(struct xarray *xa, unsigned long start_index,
			 unsigned long last_index)
{
	XA_STATE(xas, xa, start_index);
	void *entry;

	xas_lock(&xas);
	xas_for_each(&xas, entry, last_index)
		xas_store(&xas, NULL);
	xas_unlock(&xas);
}
574  
/*
 * pages_to_xarray() - Record pages as PFN value entries in an xarray.
 * @xa: destination xarray
 * @start_index: index stored for pages[0]
 * @last_index: index stored for the final page (inclusive)
 * @pages: last_index - start_index + 1 page pointers
 *
 * Stores xa_mk_value(page_to_pfn(page)) at consecutive indexes. A previous
 * non-NULL entry at any index triggers a WARN and is overwritten.
 *
 * Follows the xas_nomem() pattern. If a store fails, the lock is dropped and
 * xas_nomem() allocates with GFP_KERNEL. If that succeeds, the loop resumes
 * at the same index.
 *
 * Return: 0 on success, otherwise the xarray error. On failure, the entries
 * this call already stored ([start_index, failing index)) are erased with
 * clear_xarray().
 */
static int pages_to_xarray(struct xarray *xa, unsigned long start_index,
			   unsigned long last_index, struct page **pages)
{
	struct page **end_pages = pages + (last_index - start_index) + 1;
	/* Injected failures happen halfway, so the unwind has work to do */
	struct page **half_pages = pages + (end_pages - pages) / 2;
	XA_STATE(xas, xa, start_index);

	do {
		void *old;

		xas_lock(&xas);
		while (pages != end_pages) {
			/*
			 * xarray does not participate in fault injection, so
			 * simulate a failure here instead.
			 */
			if (pages == half_pages && iommufd_should_fail()) {
				xas_set_err(&xas, -EINVAL);
				xas_unlock(&xas);
				/*
				 * aka xas_destroy(): with a non-ENOMEM error
				 * this only frees any preallocated node.
				 */
				xas_nomem(&xas, GFP_KERNEL);
				goto err_clear;
			}

			old = xas_store(&xas, xa_mk_value(page_to_pfn(*pages)));
			/* Leave the loop so xas_nomem() can act on the error */
			if (xas_error(&xas))
				break;
			WARN_ON(old);
			pages++;
			xas_next(&xas);
		}
		xas_unlock(&xas);
	} while (xas_nomem(&xas, GFP_KERNEL));

err_clear:
	if (xas_error(&xas)) {
		/* xa_index is the first index that was not stored */
		if (xas.xa_index != start_index)
			clear_xarray(xa, start_index, xas.xa_index - 1);
		return xas_error(&xas);
	}
	return 0;
}
614  
/*
 * Append the PFN of each of the @npages pages in @pages to @batch. Stops at
 * the first PFN that does not fit.
 */
static void batch_from_pages(struct pfn_batch *batch, struct page **pages,
			     size_t npages)
{
	size_t i;

	for (i = 0; i < npages; i++) {
		if (!batch_add_pfn(batch, page_to_pfn(pages[i])))
			return;
	}
}
624  
/*
 * Unpin @npages pages recorded in @batch and remove them from
 * pages->npinned. Unpinning starts @first_page_off PFNs into the batch.
 *
 * Pages are dirtied on unpin when pages->writable is set; that flag is passed
 * as make_dirty to unpin_user_page_range_dirty_lock(). The range must lie
 * within the batch.
 */
static void batch_unpin(struct pfn_batch *batch, struct iopt_pages *pages,
			unsigned int first_page_off, size_t npages)
{
	unsigned int run = 0;

	/* Skip whole runs that end at or before first_page_off */
	while (first_page_off && batch->npfns[run] <= first_page_off) {
		first_page_off -= batch->npfns[run];
		run++;
	}

	for (; npages; run++) {
		size_t count = min_t(size_t, npages,
				     batch->npfns[run] - first_page_off);

		unpin_user_page_range_dirty_lock(
			pfn_to_page(batch->pfns[run] + first_page_off),
			count, pages->writable);
		iopt_pages_sub_npinned(pages, count);
		npages -= count;
		first_page_off = 0;
	}
}
650  
/*
 * Copy @length bytes between @data and @page at byte @offset.
 *
 * If IOMMUFD_ACCESS_RW_WRITE is set in @flags, @data is written into the
 * page and the page is marked dirty. Otherwise the page contents are read
 * into @data.
 */
static void copy_data_page(struct page *page, void *data, unsigned long offset,
			   size_t length, unsigned int flags)
{
	void *kaddr = kmap_local_page(page);

	if (!(flags & IOMMUFD_ACCESS_RW_WRITE)) {
		memcpy(data, kaddr + offset, length);
	} else {
		memcpy(kaddr + offset, data, length);
		set_page_dirty_lock(page);
	}
	kunmap_local(kaddr);
}
665  
batch_rw(struct pfn_batch * batch,void * data,unsigned long offset,unsigned long length,unsigned int flags)666  static unsigned long batch_rw(struct pfn_batch *batch, void *data,
667  			      unsigned long offset, unsigned long length,
668  			      unsigned int flags)
669  {
670  	unsigned long copied = 0;
671  	unsigned int npage = 0;
672  	unsigned int cur = 0;
673  
674  	while (cur < batch->end) {
675  		unsigned long bytes = min(length, PAGE_SIZE - offset);
676  
677  		copy_data_page(pfn_to_page(batch->pfns[cur] + npage), data,
678  			       offset, bytes, flags);
679  		offset = 0;
680  		length -= bytes;
681  		data += bytes;
682  		copied += bytes;
683  		npage++;
684  		if (npage == batch->npfns[cur]) {
685  			npage = 0;
686  			cur++;
687  		}
688  		if (!length)
689  			break;
690  	}
691  	return copied;
692  }
693  
/*
 * pfn_reader_user is just the pin_user_pages() path: the state used to pin
 * user memory with pin_user_pages_fast()/pin_user_pages_remote().
 */
struct pfn_reader_user {
	/* Array of pinned page pointers, freed by pfn_reader_user_destroy() */
	struct page **upages;
	/* Size of upages in bytes (may be clamped by temp_kmalloc()) */
	size_t upages_len;
	/* upages[0] is index upages_start; upages_end is exclusive */
	unsigned long upages_start;
	unsigned long upages_end;
	/* FOLL_* flags passed to pin_user_pages*() */
	unsigned int gup_flags;
	/*
	 * 1 means mmget() and mmap_read_lock(), 0 means only mmget(), -1 is
	 * neither. pfn_reader_user_pin() takes the mmget() only when
	 * source_mm is not current->mm.
	 */
	int locked;
};
707  
pfn_reader_user_init(struct pfn_reader_user * user,struct iopt_pages * pages)708  static void pfn_reader_user_init(struct pfn_reader_user *user,
709  				 struct iopt_pages *pages)
710  {
711  	user->upages = NULL;
712  	user->upages_start = 0;
713  	user->upages_end = 0;
714  	user->locked = -1;
715  
716  	user->gup_flags = FOLL_LONGTERM;
717  	if (pages->writable)
718  		user->gup_flags |= FOLL_WRITE;
719  }
720  
pfn_reader_user_destroy(struct pfn_reader_user * user,struct iopt_pages * pages)721  static void pfn_reader_user_destroy(struct pfn_reader_user *user,
722  				    struct iopt_pages *pages)
723  {
724  	if (user->locked != -1) {
725  		if (user->locked)
726  			mmap_read_unlock(pages->source_mm);
727  		if (pages->source_mm != current->mm)
728  			mmput(pages->source_mm);
729  		user->locked = -1;
730  	}
731  
732  	kfree(user->upages);
733  	user->upages = NULL;
734  }
735  
/*
 * pfn_reader_user_pin() - Pin a run of user pages into user->upages.
 * @user: reader state. The page list buffer and the mm reference/lock persist
 *        across calls and are released by pfn_reader_user_destroy().
 * @pages: iopt_pages whose uptr and source_mm supply the memory
 * @start_index: first page index to pin
 * @last_index: last page index wanted (inclusive)
 *
 * At most as many pages as fit in user->upages are requested, and
 * pin_user_pages*() may pin fewer still. On success the pinned pages are
 * described by [upages_start, upages_end) and added to pages->npinned.
 *
 * Return: 0 on success or a negative errno.
 */
static int pfn_reader_user_pin(struct pfn_reader_user *user,
			       struct iopt_pages *pages,
			       unsigned long start_index,
			       unsigned long last_index)
{
	bool remote_mm = pages->source_mm != current->mm;
	unsigned long npages;
	uintptr_t uptr;
	long rc;

	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
	    WARN_ON(last_index < start_index))
		return -EINVAL;

	if (!user->upages) {
		/*
		 * Freed by pfn_reader_user_destroy(). Sized for this first
		 * range. temp_kmalloc() may shrink upages_len, which then
		 * bounds every later pin.
		 */
		user->upages_len =
			(last_index - start_index + 1) * sizeof(*user->upages);
		user->upages = temp_kmalloc(&user->upages_len, NULL, 0);
		if (!user->upages)
			return -ENOMEM;
	}

	if (user->locked == -1) {
		/*
		 * The majority of usages will run the map task within the mm
		 * providing the pages, so we can optimize into
		 * get_user_pages_fast(). Only a remote mm needs a reference;
		 * it is dropped by pfn_reader_user_destroy().
		 */
		if (remote_mm) {
			if (!mmget_not_zero(pages->source_mm))
				return -EFAULT;
		}
		user->locked = 0;
	}

	/* Never pin more than the upages buffer can hold */
	npages = min_t(unsigned long, last_index - start_index + 1,
		       user->upages_len / sizeof(*user->upages));

	if (iommufd_should_fail())
		return -EFAULT;

	uptr = (uintptr_t)(pages->uptr + start_index * PAGE_SIZE);
	if (!remote_mm)
		rc = pin_user_pages_fast(uptr, npages, user->gup_flags,
					 user->upages);
	else {
		/*
		 * The read lock is kept across calls. It is dropped by
		 * pfn_reader_user_destroy() or update_mm_locked_vm().
		 */
		if (!user->locked) {
			mmap_read_lock(pages->source_mm);
			user->locked = 1;
		}
		/* May drop mmap_lock, in which case it clears user->locked */
		rc = pin_user_pages_remote(pages->source_mm, uptr, npages,
					   user->gup_flags, user->upages,
					   &user->locked);
	}
	if (rc <= 0) {
		/* Pinning nothing without reporting an error is unexpected */
		if (WARN_ON(!rc))
			return -EFAULT;
		return rc;
	}
	iopt_pages_add_npinned(pages, rc);
	user->upages_start = start_index;
	user->upages_end = start_index + rc;
	return 0;
}
802  
803  /* This is the "modern" and faster accounting method used by io_uring */
incr_user_locked_vm(struct iopt_pages * pages,unsigned long npages)804  static int incr_user_locked_vm(struct iopt_pages *pages, unsigned long npages)
805  {
806  	unsigned long lock_limit;
807  	unsigned long cur_pages;
808  	unsigned long new_pages;
809  
810  	lock_limit = task_rlimit(pages->source_task, RLIMIT_MEMLOCK) >>
811  		     PAGE_SHIFT;
812  
813  	cur_pages = atomic_long_read(&pages->source_user->locked_vm);
814  	do {
815  		new_pages = cur_pages + npages;
816  		if (new_pages > lock_limit)
817  			return -ENOMEM;
818  	} while (!atomic_long_try_cmpxchg(&pages->source_user->locked_vm,
819  					  &cur_pages, new_pages));
820  	return 0;
821  }
822  
decr_user_locked_vm(struct iopt_pages * pages,unsigned long npages)823  static void decr_user_locked_vm(struct iopt_pages *pages, unsigned long npages)
824  {
825  	if (WARN_ON(atomic_long_read(&pages->source_user->locked_vm) < npages))
826  		return;
827  	atomic_long_sub(npages, &pages->source_user->locked_vm);
828  }
829  
830  /* This is the accounting method used for compatibility with VFIO */
update_mm_locked_vm(struct iopt_pages * pages,unsigned long npages,bool inc,struct pfn_reader_user * user)831  static int update_mm_locked_vm(struct iopt_pages *pages, unsigned long npages,
832  			       bool inc, struct pfn_reader_user *user)
833  {
834  	bool do_put = false;
835  	int rc;
836  
837  	if (user && user->locked) {
838  		mmap_read_unlock(pages->source_mm);
839  		user->locked = 0;
840  		/* If we had the lock then we also have a get */
841  	} else if ((!user || !user->upages) &&
842  		   pages->source_mm != current->mm) {
843  		if (!mmget_not_zero(pages->source_mm))
844  			return -EINVAL;
845  		do_put = true;
846  	}
847  
848  	mmap_write_lock(pages->source_mm);
849  	rc = __account_locked_vm(pages->source_mm, npages, inc,
850  				 pages->source_task, false);
851  	mmap_write_unlock(pages->source_mm);
852  
853  	if (do_put)
854  		mmput(pages->source_mm);
855  	return rc;
856  }
857  
/*
 * Apply a change of @npages pinned pages (@inc picks the direction) to the
 * accounting selected by pages->account_mode. The modes are none, the
 * per-user locked_vm, or the mm's locked_vm.
 *
 * On success, records the current npinned in last_npinned and adjusts
 * source_mm->pinned_vm. On failure the error is returned and those two are
 * left unchanged.
 */
static int do_update_pinned(struct iopt_pages *pages, unsigned long npages,
			    bool inc, struct pfn_reader_user *user)
{
	int rc = 0;

	if (pages->account_mode == IOPT_PAGES_ACCOUNT_USER) {
		if (!inc)
			decr_user_locked_vm(pages, npages);
		else
			rc = incr_user_locked_vm(pages, npages);
	} else if (pages->account_mode == IOPT_PAGES_ACCOUNT_MM) {
		rc = update_mm_locked_vm(pages, npages, inc, user);
	}
	/* IOPT_PAGES_ACCOUNT_NONE charges nothing */
	if (rc)
		return rc;

	pages->last_npinned = pages->npinned;
	if (!inc)
		atomic64_sub(npages, &pages->source_mm->pinned_vm);
	else
		atomic64_add(npages, &pages->source_mm->pinned_vm);
	return 0;
}
886  
update_unpinned(struct iopt_pages * pages)887  static void update_unpinned(struct iopt_pages *pages)
888  {
889  	if (WARN_ON(pages->npinned > pages->last_npinned))
890  		return;
891  	if (pages->npinned == pages->last_npinned)
892  		return;
893  	do_update_pinned(pages, pages->last_npinned - pages->npinned, false,
894  			 NULL);
895  }
896  
897  /*
898   * Changes in the number of pages pinned is done after the pages have been read
899   * and processed. If the user lacked the limit then the error unwind will unpin
900   * everything that was just pinned. This is because it is expensive to calculate
901   * how many pages we have already pinned within a range to generate an accurate
902   * prediction in advance of doing the work to actually pin them.
903   */
pfn_reader_user_update_pinned(struct pfn_reader_user * user,struct iopt_pages * pages)904  static int pfn_reader_user_update_pinned(struct pfn_reader_user *user,
905  					 struct iopt_pages *pages)
906  {
907  	unsigned long npages;
908  	bool inc;
909  
910  	lockdep_assert_held(&pages->mutex);
911  
912  	if (pages->npinned == pages->last_npinned)
913  		return 0;
914  
915  	if (pages->npinned < pages->last_npinned) {
916  		npages = pages->last_npinned - pages->npinned;
917  		inc = false;
918  	} else {
919  		if (iommufd_should_fail())
920  			return -ENOMEM;
921  		npages = pages->npinned - pages->last_npinned;
922  		inc = true;
923  	}
924  	return do_update_pinned(pages, npages, inc, user);
925  }
926  
927  /*
928   * PFNs are stored in three places, in order of preference:
929   * - The iopt_pages xarray. This is only populated if there is a
930   *   iopt_pages_access
931   * - The iommu_domain under an area
932   * - The original PFN source, ie pages->source_mm
933   *
934   * This iterator reads the pfns optimizing to load according to the
935   * above order.
936   */
937  struct pfn_reader {
938  	struct iopt_pages *pages;
939  	struct interval_tree_double_span_iter span;
940  	struct pfn_batch batch;
941  	unsigned long batch_start_index;
942  	unsigned long batch_end_index;
943  	unsigned long last_index;
944  
945  	struct pfn_reader_user user;
946  };
947  
pfn_reader_update_pinned(struct pfn_reader * pfns)948  static int pfn_reader_update_pinned(struct pfn_reader *pfns)
949  {
950  	return pfn_reader_user_update_pinned(&pfns->user, pfns->pages);
951  }
952  
953  /*
954   * The batch can contain a mixture of pages that are still in use and pages that
955   * need to be unpinned. Unpin only pages that are not held anywhere else.
956   */
pfn_reader_unpin(struct pfn_reader * pfns)957  static void pfn_reader_unpin(struct pfn_reader *pfns)
958  {
959  	unsigned long last = pfns->batch_end_index - 1;
960  	unsigned long start = pfns->batch_start_index;
961  	struct interval_tree_double_span_iter span;
962  	struct iopt_pages *pages = pfns->pages;
963  
964  	lockdep_assert_held(&pages->mutex);
965  
966  	interval_tree_for_each_double_span(&span, &pages->access_itree,
967  					   &pages->domains_itree, start, last) {
968  		if (span.is_used)
969  			continue;
970  
971  		batch_unpin(&pfns->batch, pages, span.start_hole - start,
972  			    span.last_hole - span.start_hole + 1);
973  	}
974  }
975  
976  /* Process a single span to load it from the proper storage */
pfn_reader_fill_span(struct pfn_reader * pfns)977  static int pfn_reader_fill_span(struct pfn_reader *pfns)
978  {
979  	struct interval_tree_double_span_iter *span = &pfns->span;
980  	unsigned long start_index = pfns->batch_end_index;
981  	struct iopt_area *area;
982  	int rc;
983  
984  	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
985  	    WARN_ON(span->last_used < start_index))
986  		return -EINVAL;
987  
988  	if (span->is_used == 1) {
989  		batch_from_xarray(&pfns->batch, &pfns->pages->pinned_pfns,
990  				  start_index, span->last_used);
991  		return 0;
992  	}
993  
994  	if (span->is_used == 2) {
995  		/*
996  		 * Pull as many pages from the first domain we find in the
997  		 * target span. If it is too small then we will be called again
998  		 * and we'll find another area.
999  		 */
1000  		area = iopt_pages_find_domain_area(pfns->pages, start_index);
1001  		if (WARN_ON(!area))
1002  			return -EINVAL;
1003  
1004  		/* The storage_domain cannot change without the pages mutex */
1005  		batch_from_domain(
1006  			&pfns->batch, area->storage_domain, area, start_index,
1007  			min(iopt_area_last_index(area), span->last_used));
1008  		return 0;
1009  	}
1010  
1011  	if (start_index >= pfns->user.upages_end) {
1012  		rc = pfn_reader_user_pin(&pfns->user, pfns->pages, start_index,
1013  					 span->last_hole);
1014  		if (rc)
1015  			return rc;
1016  	}
1017  
1018  	batch_from_pages(&pfns->batch,
1019  			 pfns->user.upages +
1020  				 (start_index - pfns->user.upages_start),
1021  			 pfns->user.upages_end - start_index);
1022  	return 0;
1023  }
1024  
pfn_reader_done(struct pfn_reader * pfns)1025  static bool pfn_reader_done(struct pfn_reader *pfns)
1026  {
1027  	return pfns->batch_start_index == pfns->last_index + 1;
1028  }
1029  
pfn_reader_next(struct pfn_reader * pfns)1030  static int pfn_reader_next(struct pfn_reader *pfns)
1031  {
1032  	int rc;
1033  
1034  	batch_clear(&pfns->batch);
1035  	pfns->batch_start_index = pfns->batch_end_index;
1036  
1037  	while (pfns->batch_end_index != pfns->last_index + 1) {
1038  		unsigned int npfns = pfns->batch.total_pfns;
1039  
1040  		if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
1041  		    WARN_ON(interval_tree_double_span_iter_done(&pfns->span)))
1042  			return -EINVAL;
1043  
1044  		rc = pfn_reader_fill_span(pfns);
1045  		if (rc)
1046  			return rc;
1047  
1048  		if (WARN_ON(!pfns->batch.total_pfns))
1049  			return -EINVAL;
1050  
1051  		pfns->batch_end_index =
1052  			pfns->batch_start_index + pfns->batch.total_pfns;
1053  		if (pfns->batch_end_index == pfns->span.last_used + 1)
1054  			interval_tree_double_span_iter_next(&pfns->span);
1055  
1056  		/* Batch is full */
1057  		if (npfns == pfns->batch.total_pfns)
1058  			return 0;
1059  	}
1060  	return 0;
1061  }
1062  
pfn_reader_init(struct pfn_reader * pfns,struct iopt_pages * pages,unsigned long start_index,unsigned long last_index)1063  static int pfn_reader_init(struct pfn_reader *pfns, struct iopt_pages *pages,
1064  			   unsigned long start_index, unsigned long last_index)
1065  {
1066  	int rc;
1067  
1068  	lockdep_assert_held(&pages->mutex);
1069  
1070  	pfns->pages = pages;
1071  	pfns->batch_start_index = start_index;
1072  	pfns->batch_end_index = start_index;
1073  	pfns->last_index = last_index;
1074  	pfn_reader_user_init(&pfns->user, pages);
1075  	rc = batch_init(&pfns->batch, last_index - start_index + 1);
1076  	if (rc)
1077  		return rc;
1078  	interval_tree_double_span_iter_first(&pfns->span, &pages->access_itree,
1079  					     &pages->domains_itree, start_index,
1080  					     last_index);
1081  	return 0;
1082  }
1083  
1084  /*
1085   * There are many assertions regarding the state of pages->npinned vs
1086   * pages->last_pinned, for instance something like unmapping a domain must only
1087   * decrement the npinned, and pfn_reader_destroy() must be called only after all
1088   * the pins are updated. This is fine for success flows, but error flows
1089   * sometimes need to release the pins held inside the pfn_reader before going on
1090   * to complete unmapping and releasing pins held in domains.
1091   */
pfn_reader_release_pins(struct pfn_reader * pfns)1092  static void pfn_reader_release_pins(struct pfn_reader *pfns)
1093  {
1094  	struct iopt_pages *pages = pfns->pages;
1095  
1096  	if (pfns->user.upages_end > pfns->batch_end_index) {
1097  		size_t npages = pfns->user.upages_end - pfns->batch_end_index;
1098  
1099  		/* Any pages not transferred to the batch are just unpinned */
1100  		unpin_user_pages(pfns->user.upages + (pfns->batch_end_index -
1101  						      pfns->user.upages_start),
1102  				 npages);
1103  		iopt_pages_sub_npinned(pages, npages);
1104  		pfns->user.upages_end = pfns->batch_end_index;
1105  	}
1106  	if (pfns->batch_start_index != pfns->batch_end_index) {
1107  		pfn_reader_unpin(pfns);
1108  		pfns->batch_start_index = pfns->batch_end_index;
1109  	}
1110  }
1111  
pfn_reader_destroy(struct pfn_reader * pfns)1112  static void pfn_reader_destroy(struct pfn_reader *pfns)
1113  {
1114  	struct iopt_pages *pages = pfns->pages;
1115  
1116  	pfn_reader_release_pins(pfns);
1117  	pfn_reader_user_destroy(&pfns->user, pfns->pages);
1118  	batch_destroy(&pfns->batch, NULL);
1119  	WARN_ON(pages->last_npinned != pages->npinned);
1120  }
1121  
pfn_reader_first(struct pfn_reader * pfns,struct iopt_pages * pages,unsigned long start_index,unsigned long last_index)1122  static int pfn_reader_first(struct pfn_reader *pfns, struct iopt_pages *pages,
1123  			    unsigned long start_index, unsigned long last_index)
1124  {
1125  	int rc;
1126  
1127  	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
1128  	    WARN_ON(last_index < start_index))
1129  		return -EINVAL;
1130  
1131  	rc = pfn_reader_init(pfns, pages, start_index, last_index);
1132  	if (rc)
1133  		return rc;
1134  	rc = pfn_reader_next(pfns);
1135  	if (rc) {
1136  		pfn_reader_destroy(pfns);
1137  		return rc;
1138  	}
1139  	return 0;
1140  }
1141  
iopt_alloc_pages(void __user * uptr,unsigned long length,bool writable)1142  struct iopt_pages *iopt_alloc_pages(void __user *uptr, unsigned long length,
1143  				    bool writable)
1144  {
1145  	struct iopt_pages *pages;
1146  	unsigned long end;
1147  
1148  	/*
1149  	 * The iommu API uses size_t as the length, and protect the DIV_ROUND_UP
1150  	 * below from overflow
1151  	 */
1152  	if (length > SIZE_MAX - PAGE_SIZE || length == 0)
1153  		return ERR_PTR(-EINVAL);
1154  
1155  	if (check_add_overflow((unsigned long)uptr, length, &end))
1156  		return ERR_PTR(-EOVERFLOW);
1157  
1158  	pages = kzalloc(sizeof(*pages), GFP_KERNEL_ACCOUNT);
1159  	if (!pages)
1160  		return ERR_PTR(-ENOMEM);
1161  
1162  	kref_init(&pages->kref);
1163  	xa_init_flags(&pages->pinned_pfns, XA_FLAGS_ACCOUNT);
1164  	mutex_init(&pages->mutex);
1165  	pages->source_mm = current->mm;
1166  	mmgrab(pages->source_mm);
1167  	pages->uptr = (void __user *)ALIGN_DOWN((uintptr_t)uptr, PAGE_SIZE);
1168  	pages->npages = DIV_ROUND_UP(length + (uptr - pages->uptr), PAGE_SIZE);
1169  	pages->access_itree = RB_ROOT_CACHED;
1170  	pages->domains_itree = RB_ROOT_CACHED;
1171  	pages->writable = writable;
1172  	if (capable(CAP_IPC_LOCK))
1173  		pages->account_mode = IOPT_PAGES_ACCOUNT_NONE;
1174  	else
1175  		pages->account_mode = IOPT_PAGES_ACCOUNT_USER;
1176  	pages->source_task = current->group_leader;
1177  	get_task_struct(current->group_leader);
1178  	pages->source_user = get_uid(current_user());
1179  	return pages;
1180  }
1181  
iopt_release_pages(struct kref * kref)1182  void iopt_release_pages(struct kref *kref)
1183  {
1184  	struct iopt_pages *pages = container_of(kref, struct iopt_pages, kref);
1185  
1186  	WARN_ON(!RB_EMPTY_ROOT(&pages->access_itree.rb_root));
1187  	WARN_ON(!RB_EMPTY_ROOT(&pages->domains_itree.rb_root));
1188  	WARN_ON(pages->npinned);
1189  	WARN_ON(!xa_empty(&pages->pinned_pfns));
1190  	mmdrop(pages->source_mm);
1191  	mutex_destroy(&pages->mutex);
1192  	put_task_struct(pages->source_task);
1193  	free_uid(pages->source_user);
1194  	kfree(pages);
1195  }
1196  
/*
 * Unmap @domain over the hole [start_index, last_index] and unpin its PFNs.
 *
 * @batch may already carry PFNs that a previous call read and unmapped.
 * *unmapped_end_index is the exclusive end of the range already unmapped from
 * @domain, and it is advanced here. @real_last_index is the last index of the
 * whole unfill operation. Reading may continue past last_index, up to
 * real_last_index, so that the unmap ends where the PFNs stop being
 * contiguous. PFNs read past last_index are left in the batch as a carry for
 * the caller's following spans.
 */
static void
iopt_area_unpin_domain(struct pfn_batch *batch, struct iopt_area *area,
		       struct iopt_pages *pages, struct iommu_domain *domain,
		       unsigned long start_index, unsigned long last_index,
		       unsigned long *unmapped_end_index,
		       unsigned long real_last_index)
{
	while (start_index <= last_index) {
		unsigned long batch_last_index;

		/*
		 * Part of the hole is not yet unmapped: read its PFNs from the
		 * domain, appending after any carried PFNs.
		 */
		if (*unmapped_end_index <= last_index) {
			unsigned long start =
				max(start_index, *unmapped_end_index);

			/* Carried PFNs must begin exactly at start_index */
			if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
			    batch->total_pfns)
				WARN_ON(*unmapped_end_index -
						batch->total_pfns !=
					start_index);
			batch_from_domain(batch, domain, area, start,
					  last_index);
			batch_last_index = start_index + batch->total_pfns - 1;
		} else {
			/* The rest of the hole is already in the carry */
			batch_last_index = last_index;
		}

		if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
			WARN_ON(batch_last_index > real_last_index);

		/*
		 * unmaps must always 'cut' at a place where the pfns are not
		 * contiguous to pair with the maps that always install
		 * contiguous pages. Thus, if we have to stop unpinning in the
		 * middle of the domains we need to keep reading pfns until we
		 * find a cut point to do the unmap. The pfns we read are
		 * carried over and either skipped or integrated into the next
		 * batch.
		 */
		if (batch_last_index == last_index &&
		    last_index != real_last_index)
			batch_from_domain_continue(batch, domain, area,
						   last_index + 1,
						   real_last_index);

		/* Unmap everything now read, including any read-ahead */
		if (*unmapped_end_index <= batch_last_index) {
			iopt_area_unmap_domain_range(
				area, domain, *unmapped_end_index,
				start_index + batch->total_pfns - 1);
			*unmapped_end_index = start_index + batch->total_pfns;
		}

		/* unpin must follow unmap */
		batch_unpin(batch, pages, 0,
			    batch_last_index - start_index + 1);
		start_index = batch_last_index + 1;

		/* Keep the unmapped-but-not-unpinned read-ahead as a carry */
		batch_clear_carry(batch,
				  *unmapped_end_index - batch_last_index - 1);
	}
}
1257  
__iopt_area_unfill_domain(struct iopt_area * area,struct iopt_pages * pages,struct iommu_domain * domain,unsigned long last_index)1258  static void __iopt_area_unfill_domain(struct iopt_area *area,
1259  				      struct iopt_pages *pages,
1260  				      struct iommu_domain *domain,
1261  				      unsigned long last_index)
1262  {
1263  	struct interval_tree_double_span_iter span;
1264  	unsigned long start_index = iopt_area_index(area);
1265  	unsigned long unmapped_end_index = start_index;
1266  	u64 backup[BATCH_BACKUP_SIZE];
1267  	struct pfn_batch batch;
1268  
1269  	lockdep_assert_held(&pages->mutex);
1270  
1271  	/*
1272  	 * For security we must not unpin something that is still DMA mapped,
1273  	 * so this must unmap any IOVA before we go ahead and unpin the pages.
1274  	 * This creates a complexity where we need to skip over unpinning pages
1275  	 * held in the xarray, but continue to unmap from the domain.
1276  	 *
1277  	 * The domain unmap cannot stop in the middle of a contiguous range of
1278  	 * PFNs. To solve this problem the unpinning step will read ahead to the
1279  	 * end of any contiguous span, unmap that whole span, and then only
1280  	 * unpin the leading part that does not have any accesses. The residual
1281  	 * PFNs that were unmapped but not unpinned are called a "carry" in the
1282  	 * batch as they are moved to the front of the PFN list and continue on
1283  	 * to the next iteration(s).
1284  	 */
1285  	batch_init_backup(&batch, last_index + 1, backup, sizeof(backup));
1286  	interval_tree_for_each_double_span(&span, &pages->domains_itree,
1287  					   &pages->access_itree, start_index,
1288  					   last_index) {
1289  		if (span.is_used) {
1290  			batch_skip_carry(&batch,
1291  					 span.last_used - span.start_used + 1);
1292  			continue;
1293  		}
1294  		iopt_area_unpin_domain(&batch, area, pages, domain,
1295  				       span.start_hole, span.last_hole,
1296  				       &unmapped_end_index, last_index);
1297  	}
1298  	/*
1299  	 * If the range ends in a access then we do the residual unmap without
1300  	 * any unpins.
1301  	 */
1302  	if (unmapped_end_index != last_index + 1)
1303  		iopt_area_unmap_domain_range(area, domain, unmapped_end_index,
1304  					     last_index);
1305  	WARN_ON(batch.total_pfns);
1306  	batch_destroy(&batch, backup);
1307  	update_unpinned(pages);
1308  }
1309  
iopt_area_unfill_partial_domain(struct iopt_area * area,struct iopt_pages * pages,struct iommu_domain * domain,unsigned long end_index)1310  static void iopt_area_unfill_partial_domain(struct iopt_area *area,
1311  					    struct iopt_pages *pages,
1312  					    struct iommu_domain *domain,
1313  					    unsigned long end_index)
1314  {
1315  	if (end_index != iopt_area_index(area))
1316  		__iopt_area_unfill_domain(area, pages, domain, end_index - 1);
1317  }
1318  
1319  /**
1320   * iopt_area_unmap_domain() - Unmap without unpinning PFNs in a domain
1321   * @area: The IOVA range to unmap
1322   * @domain: The domain to unmap
1323   *
1324   * The caller must know that unpinning is not required, usually because there
1325   * are other domains in the iopt.
1326   */
iopt_area_unmap_domain(struct iopt_area * area,struct iommu_domain * domain)1327  void iopt_area_unmap_domain(struct iopt_area *area, struct iommu_domain *domain)
1328  {
1329  	iommu_unmap_nofail(domain, iopt_area_iova(area),
1330  			   iopt_area_length(area));
1331  }
1332  
1333  /**
1334   * iopt_area_unfill_domain() - Unmap and unpin PFNs in a domain
1335   * @area: IOVA area to use
1336   * @pages: page supplier for the area (area->pages is NULL)
1337   * @domain: Domain to unmap from
1338   *
1339   * The domain should be removed from the domains_itree before calling. The
1340   * domain will always be unmapped, but the PFNs may not be unpinned if there are
1341   * still accesses.
1342   */
iopt_area_unfill_domain(struct iopt_area * area,struct iopt_pages * pages,struct iommu_domain * domain)1343  void iopt_area_unfill_domain(struct iopt_area *area, struct iopt_pages *pages,
1344  			     struct iommu_domain *domain)
1345  {
1346  	__iopt_area_unfill_domain(area, pages, domain,
1347  				  iopt_area_last_index(area));
1348  }
1349  
1350  /**
1351   * iopt_area_fill_domain() - Map PFNs from the area into a domain
1352   * @area: IOVA area to use
1353   * @domain: Domain to load PFNs into
1354   *
1355   * Read the pfns from the area's underlying iopt_pages and map them into the
1356   * given domain. Called when attaching a new domain to an io_pagetable.
1357   */
iopt_area_fill_domain(struct iopt_area * area,struct iommu_domain * domain)1358  int iopt_area_fill_domain(struct iopt_area *area, struct iommu_domain *domain)
1359  {
1360  	unsigned long done_end_index;
1361  	struct pfn_reader pfns;
1362  	int rc;
1363  
1364  	lockdep_assert_held(&area->pages->mutex);
1365  
1366  	rc = pfn_reader_first(&pfns, area->pages, iopt_area_index(area),
1367  			      iopt_area_last_index(area));
1368  	if (rc)
1369  		return rc;
1370  
1371  	while (!pfn_reader_done(&pfns)) {
1372  		done_end_index = pfns.batch_start_index;
1373  		rc = batch_to_domain(&pfns.batch, domain, area,
1374  				     pfns.batch_start_index);
1375  		if (rc)
1376  			goto out_unmap;
1377  		done_end_index = pfns.batch_end_index;
1378  
1379  		rc = pfn_reader_next(&pfns);
1380  		if (rc)
1381  			goto out_unmap;
1382  	}
1383  
1384  	rc = pfn_reader_update_pinned(&pfns);
1385  	if (rc)
1386  		goto out_unmap;
1387  	goto out_destroy;
1388  
1389  out_unmap:
1390  	pfn_reader_release_pins(&pfns);
1391  	iopt_area_unfill_partial_domain(area, area->pages, domain,
1392  					done_end_index);
1393  out_destroy:
1394  	pfn_reader_destroy(&pfns);
1395  	return rc;
1396  }
1397  
1398  /**
1399   * iopt_area_fill_domains() - Install PFNs into the area's domains
1400   * @area: The area to act on
1401   * @pages: The pages associated with the area (area->pages is NULL)
1402   *
1403   * Called during area creation. The area is freshly created and not inserted in
1404   * the domains_itree yet. PFNs are read and loaded into every domain held in the
1405   * area's io_pagetable and the area is installed in the domains_itree.
1406   *
1407   * On failure all domains are left unchanged.
1408   */
iopt_area_fill_domains(struct iopt_area * area,struct iopt_pages * pages)1409  int iopt_area_fill_domains(struct iopt_area *area, struct iopt_pages *pages)
1410  {
1411  	unsigned long done_first_end_index;
1412  	unsigned long done_all_end_index;
1413  	struct iommu_domain *domain;
1414  	unsigned long unmap_index;
1415  	struct pfn_reader pfns;
1416  	unsigned long index;
1417  	int rc;
1418  
1419  	lockdep_assert_held(&area->iopt->domains_rwsem);
1420  
1421  	if (xa_empty(&area->iopt->domains))
1422  		return 0;
1423  
1424  	mutex_lock(&pages->mutex);
1425  	rc = pfn_reader_first(&pfns, pages, iopt_area_index(area),
1426  			      iopt_area_last_index(area));
1427  	if (rc)
1428  		goto out_unlock;
1429  
1430  	while (!pfn_reader_done(&pfns)) {
1431  		done_first_end_index = pfns.batch_end_index;
1432  		done_all_end_index = pfns.batch_start_index;
1433  		xa_for_each(&area->iopt->domains, index, domain) {
1434  			rc = batch_to_domain(&pfns.batch, domain, area,
1435  					     pfns.batch_start_index);
1436  			if (rc)
1437  				goto out_unmap;
1438  		}
1439  		done_all_end_index = done_first_end_index;
1440  
1441  		rc = pfn_reader_next(&pfns);
1442  		if (rc)
1443  			goto out_unmap;
1444  	}
1445  	rc = pfn_reader_update_pinned(&pfns);
1446  	if (rc)
1447  		goto out_unmap;
1448  
1449  	area->storage_domain = xa_load(&area->iopt->domains, 0);
1450  	interval_tree_insert(&area->pages_node, &pages->domains_itree);
1451  	goto out_destroy;
1452  
1453  out_unmap:
1454  	pfn_reader_release_pins(&pfns);
1455  	xa_for_each(&area->iopt->domains, unmap_index, domain) {
1456  		unsigned long end_index;
1457  
1458  		if (unmap_index < index)
1459  			end_index = done_first_end_index;
1460  		else
1461  			end_index = done_all_end_index;
1462  
1463  		/*
1464  		 * The area is not yet part of the domains_itree so we have to
1465  		 * manage the unpinning specially. The last domain does the
1466  		 * unpin, every other domain is just unmapped.
1467  		 */
1468  		if (unmap_index != area->iopt->next_domain_id - 1) {
1469  			if (end_index != iopt_area_index(area))
1470  				iopt_area_unmap_domain_range(
1471  					area, domain, iopt_area_index(area),
1472  					end_index - 1);
1473  		} else {
1474  			iopt_area_unfill_partial_domain(area, pages, domain,
1475  							end_index);
1476  		}
1477  	}
1478  out_destroy:
1479  	pfn_reader_destroy(&pfns);
1480  out_unlock:
1481  	mutex_unlock(&pages->mutex);
1482  	return rc;
1483  }
1484  
1485  /**
1486   * iopt_area_unfill_domains() - unmap PFNs from the area's domains
1487   * @area: The area to act on
1488   * @pages: The pages associated with the area (area->pages is NULL)
1489   *
1490   * Called during area destruction. This unmaps the iova's covered by all the
1491   * area's domains and releases the PFNs.
1492   */
iopt_area_unfill_domains(struct iopt_area * area,struct iopt_pages * pages)1493  void iopt_area_unfill_domains(struct iopt_area *area, struct iopt_pages *pages)
1494  {
1495  	struct io_pagetable *iopt = area->iopt;
1496  	struct iommu_domain *domain;
1497  	unsigned long index;
1498  
1499  	lockdep_assert_held(&iopt->domains_rwsem);
1500  
1501  	mutex_lock(&pages->mutex);
1502  	if (!area->storage_domain)
1503  		goto out_unlock;
1504  
1505  	xa_for_each(&iopt->domains, index, domain)
1506  		if (domain != area->storage_domain)
1507  			iopt_area_unmap_domain_range(
1508  				area, domain, iopt_area_index(area),
1509  				iopt_area_last_index(area));
1510  
1511  	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
1512  		WARN_ON(RB_EMPTY_NODE(&area->pages_node.rb));
1513  	interval_tree_remove(&area->pages_node, &pages->domains_itree);
1514  	iopt_area_unfill_domain(area, pages, area->storage_domain);
1515  	area->storage_domain = NULL;
1516  out_unlock:
1517  	mutex_unlock(&pages->mutex);
1518  }
1519  
iopt_pages_unpin_xarray(struct pfn_batch * batch,struct iopt_pages * pages,unsigned long start_index,unsigned long end_index)1520  static void iopt_pages_unpin_xarray(struct pfn_batch *batch,
1521  				    struct iopt_pages *pages,
1522  				    unsigned long start_index,
1523  				    unsigned long end_index)
1524  {
1525  	while (start_index <= end_index) {
1526  		batch_from_xarray_clear(batch, &pages->pinned_pfns, start_index,
1527  					end_index);
1528  		batch_unpin(batch, pages, 0, batch->total_pfns);
1529  		start_index += batch->total_pfns;
1530  		batch_clear(batch);
1531  	}
1532  }
1533  
1534  /**
1535   * iopt_pages_unfill_xarray() - Update the xarry after removing an access
1536   * @pages: The pages to act on
1537   * @start_index: Starting PFN index
1538   * @last_index: Last PFN index
1539   *
1540   * Called when an iopt_pages_access is removed, removes pages from the itree.
1541   * The access should already be removed from the access_itree.
1542   */
iopt_pages_unfill_xarray(struct iopt_pages * pages,unsigned long start_index,unsigned long last_index)1543  void iopt_pages_unfill_xarray(struct iopt_pages *pages,
1544  			      unsigned long start_index,
1545  			      unsigned long last_index)
1546  {
1547  	struct interval_tree_double_span_iter span;
1548  	u64 backup[BATCH_BACKUP_SIZE];
1549  	struct pfn_batch batch;
1550  	bool batch_inited = false;
1551  
1552  	lockdep_assert_held(&pages->mutex);
1553  
1554  	interval_tree_for_each_double_span(&span, &pages->access_itree,
1555  					   &pages->domains_itree, start_index,
1556  					   last_index) {
1557  		if (!span.is_used) {
1558  			if (!batch_inited) {
1559  				batch_init_backup(&batch,
1560  						  last_index - start_index + 1,
1561  						  backup, sizeof(backup));
1562  				batch_inited = true;
1563  			}
1564  			iopt_pages_unpin_xarray(&batch, pages, span.start_hole,
1565  						span.last_hole);
1566  		} else if (span.is_used == 2) {
1567  			/* Covered by a domain */
1568  			clear_xarray(&pages->pinned_pfns, span.start_used,
1569  				     span.last_used);
1570  		}
1571  		/* Otherwise covered by an existing access */
1572  	}
1573  	if (batch_inited)
1574  		batch_destroy(&batch, backup);
1575  	update_unpinned(pages);
1576  }
1577  
1578  /**
1579   * iopt_pages_fill_from_xarray() - Fast path for reading PFNs
1580   * @pages: The pages to act on
1581   * @start_index: The first page index in the range
1582   * @last_index: The last page index in the range
1583   * @out_pages: The output array to return the pages
1584   *
1585   * This can be called if the caller is holding a refcount on an
1586   * iopt_pages_access that is known to have already been filled. It quickly reads
1587   * the pages directly from the xarray.
1588   *
1589   * This is part of the SW iommu interface to read pages for in-kernel use.
1590   */
iopt_pages_fill_from_xarray(struct iopt_pages * pages,unsigned long start_index,unsigned long last_index,struct page ** out_pages)1591  void iopt_pages_fill_from_xarray(struct iopt_pages *pages,
1592  				 unsigned long start_index,
1593  				 unsigned long last_index,
1594  				 struct page **out_pages)
1595  {
1596  	XA_STATE(xas, &pages->pinned_pfns, start_index);
1597  	void *entry;
1598  
1599  	rcu_read_lock();
1600  	while (start_index <= last_index) {
1601  		entry = xas_next(&xas);
1602  		if (xas_retry(&xas, entry))
1603  			continue;
1604  		WARN_ON(!xa_is_value(entry));
1605  		*(out_pages++) = pfn_to_page(xa_to_value(entry));
1606  		start_index++;
1607  	}
1608  	rcu_read_unlock();
1609  }
1610  
iopt_pages_fill_from_domain(struct iopt_pages * pages,unsigned long start_index,unsigned long last_index,struct page ** out_pages)1611  static int iopt_pages_fill_from_domain(struct iopt_pages *pages,
1612  				       unsigned long start_index,
1613  				       unsigned long last_index,
1614  				       struct page **out_pages)
1615  {
1616  	while (start_index != last_index + 1) {
1617  		unsigned long domain_last;
1618  		struct iopt_area *area;
1619  
1620  		area = iopt_pages_find_domain_area(pages, start_index);
1621  		if (WARN_ON(!area))
1622  			return -EINVAL;
1623  
1624  		domain_last = min(iopt_area_last_index(area), last_index);
1625  		out_pages = raw_pages_from_domain(area->storage_domain, area,
1626  						  start_index, domain_last,
1627  						  out_pages);
1628  		start_index = domain_last + 1;
1629  	}
1630  	return 0;
1631  }
1632  
iopt_pages_fill_from_mm(struct iopt_pages * pages,struct pfn_reader_user * user,unsigned long start_index,unsigned long last_index,struct page ** out_pages)1633  static int iopt_pages_fill_from_mm(struct iopt_pages *pages,
1634  				   struct pfn_reader_user *user,
1635  				   unsigned long start_index,
1636  				   unsigned long last_index,
1637  				   struct page **out_pages)
1638  {
1639  	unsigned long cur_index = start_index;
1640  	int rc;
1641  
1642  	while (cur_index != last_index + 1) {
1643  		user->upages = out_pages + (cur_index - start_index);
1644  		rc = pfn_reader_user_pin(user, pages, cur_index, last_index);
1645  		if (rc)
1646  			goto out_unpin;
1647  		cur_index = user->upages_end;
1648  	}
1649  	return 0;
1650  
1651  out_unpin:
1652  	if (start_index != cur_index)
1653  		iopt_pages_err_unpin(pages, start_index, cur_index - 1,
1654  				     out_pages);
1655  	return rc;
1656  }
1657  
1658  /**
1659   * iopt_pages_fill_xarray() - Read PFNs
1660   * @pages: The pages to act on
1661   * @start_index: The first page index in the range
1662   * @last_index: The last page index in the range
1663   * @out_pages: The output array to return the pages, may be NULL
1664   *
1665   * This populates the xarray and returns the pages in out_pages. As the slow
1666   * path this is able to copy pages from other storage tiers into the xarray.
1667   *
1668   * On failure the xarray is left unchanged.
1669   *
1670   * This is part of the SW iommu interface to read pages for in-kernel use.
1671   */
iopt_pages_fill_xarray(struct iopt_pages * pages,unsigned long start_index,unsigned long last_index,struct page ** out_pages)1672  int iopt_pages_fill_xarray(struct iopt_pages *pages, unsigned long start_index,
1673  			   unsigned long last_index, struct page **out_pages)
1674  {
1675  	struct interval_tree_double_span_iter span;
1676  	unsigned long xa_end = start_index;
1677  	struct pfn_reader_user user;
1678  	int rc;
1679  
1680  	lockdep_assert_held(&pages->mutex);
1681  
1682  	pfn_reader_user_init(&user, pages);
1683  	user.upages_len = (last_index - start_index + 1) * sizeof(*out_pages);
1684  	interval_tree_for_each_double_span(&span, &pages->access_itree,
1685  					   &pages->domains_itree, start_index,
1686  					   last_index) {
1687  		struct page **cur_pages;
1688  
1689  		if (span.is_used == 1) {
1690  			cur_pages = out_pages + (span.start_used - start_index);
1691  			iopt_pages_fill_from_xarray(pages, span.start_used,
1692  						    span.last_used, cur_pages);
1693  			continue;
1694  		}
1695  
1696  		if (span.is_used == 2) {
1697  			cur_pages = out_pages + (span.start_used - start_index);
1698  			iopt_pages_fill_from_domain(pages, span.start_used,
1699  						    span.last_used, cur_pages);
1700  			rc = pages_to_xarray(&pages->pinned_pfns,
1701  					     span.start_used, span.last_used,
1702  					     cur_pages);
1703  			if (rc)
1704  				goto out_clean_xa;
1705  			xa_end = span.last_used + 1;
1706  			continue;
1707  		}
1708  
1709  		/* hole */
1710  		cur_pages = out_pages + (span.start_hole - start_index);
1711  		rc = iopt_pages_fill_from_mm(pages, &user, span.start_hole,
1712  					     span.last_hole, cur_pages);
1713  		if (rc)
1714  			goto out_clean_xa;
1715  		rc = pages_to_xarray(&pages->pinned_pfns, span.start_hole,
1716  				     span.last_hole, cur_pages);
1717  		if (rc) {
1718  			iopt_pages_err_unpin(pages, span.start_hole,
1719  					     span.last_hole, cur_pages);
1720  			goto out_clean_xa;
1721  		}
1722  		xa_end = span.last_hole + 1;
1723  	}
1724  	rc = pfn_reader_user_update_pinned(&user, pages);
1725  	if (rc)
1726  		goto out_clean_xa;
1727  	user.upages = NULL;
1728  	pfn_reader_user_destroy(&user, pages);
1729  	return 0;
1730  
1731  out_clean_xa:
1732  	if (start_index != xa_end)
1733  		iopt_pages_unfill_xarray(pages, start_index, xa_end - 1);
1734  	user.upages = NULL;
1735  	pfn_reader_user_destroy(&user, pages);
1736  	return rc;
1737  }
1738  
1739  /*
1740   * This uses the pfn_reader instead of taking a shortcut by using the mm. It can
1741   * do every scenario and is fully consistent with what an iommu_domain would
1742   * see.
1743   */
iopt_pages_rw_slow(struct iopt_pages * pages,unsigned long start_index,unsigned long last_index,unsigned long offset,void * data,unsigned long length,unsigned int flags)1744  static int iopt_pages_rw_slow(struct iopt_pages *pages,
1745  			      unsigned long start_index,
1746  			      unsigned long last_index, unsigned long offset,
1747  			      void *data, unsigned long length,
1748  			      unsigned int flags)
1749  {
1750  	struct pfn_reader pfns;
1751  	int rc;
1752  
1753  	mutex_lock(&pages->mutex);
1754  
1755  	rc = pfn_reader_first(&pfns, pages, start_index, last_index);
1756  	if (rc)
1757  		goto out_unlock;
1758  
1759  	while (!pfn_reader_done(&pfns)) {
1760  		unsigned long done;
1761  
1762  		done = batch_rw(&pfns.batch, data, offset, length, flags);
1763  		data += done;
1764  		length -= done;
1765  		offset = 0;
1766  		pfn_reader_unpin(&pfns);
1767  
1768  		rc = pfn_reader_next(&pfns);
1769  		if (rc)
1770  			goto out_destroy;
1771  	}
1772  	if (WARN_ON(length != 0))
1773  		rc = -EINVAL;
1774  out_destroy:
1775  	pfn_reader_destroy(&pfns);
1776  out_unlock:
1777  	mutex_unlock(&pages->mutex);
1778  	return rc;
1779  }
1780  
1781  /*
1782   * A medium speed path that still allows DMA inconsistencies, but doesn't do any
1783   * memory allocations or interval tree searches.
1784   */
iopt_pages_rw_page(struct iopt_pages * pages,unsigned long index,unsigned long offset,void * data,unsigned long length,unsigned int flags)1785  static int iopt_pages_rw_page(struct iopt_pages *pages, unsigned long index,
1786  			      unsigned long offset, void *data,
1787  			      unsigned long length, unsigned int flags)
1788  {
1789  	struct page *page = NULL;
1790  	int rc;
1791  
1792  	if (!mmget_not_zero(pages->source_mm))
1793  		return iopt_pages_rw_slow(pages, index, index, offset, data,
1794  					  length, flags);
1795  
1796  	if (iommufd_should_fail()) {
1797  		rc = -EINVAL;
1798  		goto out_mmput;
1799  	}
1800  
1801  	mmap_read_lock(pages->source_mm);
1802  	rc = pin_user_pages_remote(
1803  		pages->source_mm, (uintptr_t)(pages->uptr + index * PAGE_SIZE),
1804  		1, (flags & IOMMUFD_ACCESS_RW_WRITE) ? FOLL_WRITE : 0, &page,
1805  		NULL);
1806  	mmap_read_unlock(pages->source_mm);
1807  	if (rc != 1) {
1808  		if (WARN_ON(rc >= 0))
1809  			rc = -EINVAL;
1810  		goto out_mmput;
1811  	}
1812  	copy_data_page(page, data, offset, length, flags);
1813  	unpin_user_page(page);
1814  	rc = 0;
1815  
1816  out_mmput:
1817  	mmput(pages->source_mm);
1818  	return rc;
1819  }
1820  
1821  /**
1822   * iopt_pages_rw_access - Copy to/from a linear slice of the pages
1823   * @pages: pages to act on
1824   * @start_byte: First byte of pages to copy to/from
1825   * @data: Kernel buffer to get/put the data
1826   * @length: Number of bytes to copy
1827   * @flags: IOMMUFD_ACCESS_RW_* flags
1828   *
1829   * This will find each page in the range, kmap it and then memcpy to/from
1830   * the given kernel buffer.
1831   */
iopt_pages_rw_access(struct iopt_pages * pages,unsigned long start_byte,void * data,unsigned long length,unsigned int flags)1832  int iopt_pages_rw_access(struct iopt_pages *pages, unsigned long start_byte,
1833  			 void *data, unsigned long length, unsigned int flags)
1834  {
1835  	unsigned long start_index = start_byte / PAGE_SIZE;
1836  	unsigned long last_index = (start_byte + length - 1) / PAGE_SIZE;
1837  	bool change_mm = current->mm != pages->source_mm;
1838  	int rc = 0;
1839  
1840  	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
1841  	    (flags & __IOMMUFD_ACCESS_RW_SLOW_PATH))
1842  		change_mm = true;
1843  
1844  	if ((flags & IOMMUFD_ACCESS_RW_WRITE) && !pages->writable)
1845  		return -EPERM;
1846  
1847  	if (!(flags & IOMMUFD_ACCESS_RW_KTHREAD) && change_mm) {
1848  		if (start_index == last_index)
1849  			return iopt_pages_rw_page(pages, start_index,
1850  						  start_byte % PAGE_SIZE, data,
1851  						  length, flags);
1852  		return iopt_pages_rw_slow(pages, start_index, last_index,
1853  					  start_byte % PAGE_SIZE, data, length,
1854  					  flags);
1855  	}
1856  
1857  	/*
1858  	 * Try to copy using copy_to_user(). We do this as a fast path and
1859  	 * ignore any pinning inconsistencies, unlike a real DMA path.
1860  	 */
1861  	if (change_mm) {
1862  		if (!mmget_not_zero(pages->source_mm))
1863  			return iopt_pages_rw_slow(pages, start_index,
1864  						  last_index,
1865  						  start_byte % PAGE_SIZE, data,
1866  						  length, flags);
1867  		kthread_use_mm(pages->source_mm);
1868  	}
1869  
1870  	if (flags & IOMMUFD_ACCESS_RW_WRITE) {
1871  		if (copy_to_user(pages->uptr + start_byte, data, length))
1872  			rc = -EFAULT;
1873  	} else {
1874  		if (copy_from_user(data, pages->uptr + start_byte, length))
1875  			rc = -EFAULT;
1876  	}
1877  
1878  	if (change_mm) {
1879  		kthread_unuse_mm(pages->source_mm);
1880  		mmput(pages->source_mm);
1881  	}
1882  
1883  	return rc;
1884  }
1885  
1886  static struct iopt_pages_access *
iopt_pages_get_exact_access(struct iopt_pages * pages,unsigned long index,unsigned long last)1887  iopt_pages_get_exact_access(struct iopt_pages *pages, unsigned long index,
1888  			    unsigned long last)
1889  {
1890  	struct interval_tree_node *node;
1891  
1892  	lockdep_assert_held(&pages->mutex);
1893  
1894  	/* There can be overlapping ranges in this interval tree */
1895  	for (node = interval_tree_iter_first(&pages->access_itree, index, last);
1896  	     node; node = interval_tree_iter_next(node, index, last))
1897  		if (node->start == index && node->last == last)
1898  			return container_of(node, struct iopt_pages_access,
1899  					    node);
1900  	return NULL;
1901  }
1902  
1903  /**
1904   * iopt_area_add_access() - Record an in-knerel access for PFNs
1905   * @area: The source of PFNs
1906   * @start_index: First page index
1907   * @last_index: Inclusive last page index
1908   * @out_pages: Output list of struct page's representing the PFNs
1909   * @flags: IOMMUFD_ACCESS_RW_* flags
1910   *
1911   * Record that an in-kernel access will be accessing the pages, ensure they are
1912   * pinned, and return the PFNs as a simple list of 'struct page *'.
1913   *
1914   * This should be undone through a matching call to iopt_area_remove_access()
1915   */
iopt_area_add_access(struct iopt_area * area,unsigned long start_index,unsigned long last_index,struct page ** out_pages,unsigned int flags)1916  int iopt_area_add_access(struct iopt_area *area, unsigned long start_index,
1917  			  unsigned long last_index, struct page **out_pages,
1918  			  unsigned int flags)
1919  {
1920  	struct iopt_pages *pages = area->pages;
1921  	struct iopt_pages_access *access;
1922  	int rc;
1923  
1924  	if ((flags & IOMMUFD_ACCESS_RW_WRITE) && !pages->writable)
1925  		return -EPERM;
1926  
1927  	mutex_lock(&pages->mutex);
1928  	access = iopt_pages_get_exact_access(pages, start_index, last_index);
1929  	if (access) {
1930  		area->num_accesses++;
1931  		access->users++;
1932  		iopt_pages_fill_from_xarray(pages, start_index, last_index,
1933  					    out_pages);
1934  		mutex_unlock(&pages->mutex);
1935  		return 0;
1936  	}
1937  
1938  	access = kzalloc(sizeof(*access), GFP_KERNEL_ACCOUNT);
1939  	if (!access) {
1940  		rc = -ENOMEM;
1941  		goto err_unlock;
1942  	}
1943  
1944  	rc = iopt_pages_fill_xarray(pages, start_index, last_index, out_pages);
1945  	if (rc)
1946  		goto err_free;
1947  
1948  	access->node.start = start_index;
1949  	access->node.last = last_index;
1950  	access->users = 1;
1951  	area->num_accesses++;
1952  	interval_tree_insert(&access->node, &pages->access_itree);
1953  	mutex_unlock(&pages->mutex);
1954  	return 0;
1955  
1956  err_free:
1957  	kfree(access);
1958  err_unlock:
1959  	mutex_unlock(&pages->mutex);
1960  	return rc;
1961  }
1962  
1963  /**
1964   * iopt_area_remove_access() - Release an in-kernel access for PFNs
1965   * @area: The source of PFNs
1966   * @start_index: First page index
1967   * @last_index: Inclusive last page index
1968   *
1969   * Undo iopt_area_add_access() and unpin the pages if necessary. The caller
1970   * must stop using the PFNs before calling this.
1971   */
iopt_area_remove_access(struct iopt_area * area,unsigned long start_index,unsigned long last_index)1972  void iopt_area_remove_access(struct iopt_area *area, unsigned long start_index,
1973  			     unsigned long last_index)
1974  {
1975  	struct iopt_pages *pages = area->pages;
1976  	struct iopt_pages_access *access;
1977  
1978  	mutex_lock(&pages->mutex);
1979  	access = iopt_pages_get_exact_access(pages, start_index, last_index);
1980  	if (WARN_ON(!access))
1981  		goto out_unlock;
1982  
1983  	WARN_ON(area->num_accesses == 0 || access->users == 0);
1984  	area->num_accesses--;
1985  	access->users--;
1986  	if (access->users)
1987  		goto out_unlock;
1988  
1989  	interval_tree_remove(&access->node, &pages->access_itree);
1990  	iopt_pages_unfill_xarray(pages, start_index, last_index);
1991  	kfree(access);
1992  out_unlock:
1993  	mutex_unlock(&pages->mutex);
1994  }
1995