1// SPDX-License-Identifier: GPL-2.0
2/* Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES.
3 *
4 * The iopt_pages is the center of the storage and motion of PFNs. Each
5 * iopt_pages represents a logical linear array of full PFNs. The array is 0
6 * based and has npages in it. Accessors use 'index' to refer to the entry in
7 * this logical array, regardless of its storage location.
8 *
9 * PFNs are stored in a tiered scheme:
10 *  1) iopt_pages::pinned_pfns xarray
11 *  2) An iommu_domain
12 *  3) The origin of the PFNs, i.e. the userspace pointer
13 *
 * PFNs have to be copied between all combinations of tiers, depending on the
15 * configuration.
16 *
17 * When a PFN is taken out of the userspace pointer it is pinned exactly once.
18 * The storage locations of the PFN's index are tracked in the two interval
19 * trees. If no interval includes the index then it is not pinned.
20 *
21 * If access_itree includes the PFN's index then an in-kernel access has
22 * requested the page. The PFN is stored in the xarray so other requestors can
23 * continue to find it.
24 *
25 * If the domains_itree includes the PFN's index then an iommu_domain is storing
26 * the PFN and it can be read back using iommu_iova_to_phys(). To avoid
27 * duplicating storage the xarray is not used if only iommu_domains are using
28 * the PFN's index.
29 *
30 * As a general principle this is designed so that destroy never fails. This
 * means removing an iommu_domain or releasing an in-kernel access will not fail
32 * due to insufficient memory. In practice this means some cases have to hold
33 * PFNs in the xarray even though they are also being stored in an iommu_domain.
34 *
35 * While the iopt_pages can use an iommu_domain as storage, it does not have an
36 * IOVA itself. Instead the iopt_area represents a range of IOVA and uses the
37 * iopt_pages as the PFN provider. Multiple iopt_areas can share the iopt_pages
38 * and reference their own slice of the PFN array, with sub page granularity.
39 *
40 * In this file the term 'last' indicates an inclusive and closed interval, eg
41 * [0,0] refers to a single PFN. 'end' means an open range, eg [0,0) refers to
42 * no PFNs.
43 *
44 * Be cautious of overflow. An IOVA can go all the way up to U64_MAX, so
45 * last_iova + 1 can overflow. An iopt_pages index will always be much less than
46 * ULONG_MAX so last_index + 1 cannot overflow.
47 */
48#include <linux/overflow.h>
49#include <linux/slab.h>
50#include <linux/iommu.h>
51#include <linux/sched/mm.h>
52#include <linux/highmem.h>
53#include <linux/kthread.h>
54#include <linux/iommufd.h>
55
56#include "io_pagetable.h"
57#include "double_span.h"
58
59#ifndef CONFIG_IOMMUFD_TEST
60#define TEMP_MEMORY_LIMIT 65536
61#else
62#define TEMP_MEMORY_LIMIT iommufd_test_memory_limit
63#endif
64#define BATCH_BACKUP_SIZE 32
65
66/*
67 * More memory makes pin_user_pages() and the batching more efficient, but as
68 * this is only a performance optimization don't try too hard to get it. A 64k
69 * allocation can hold about 26M of 4k pages and 13G of 2M pages in an
70 * pfn_batch. Various destroy paths cannot fail and provide a small amount of
71 * stack memory as a backup contingency. If backup_len is given this cannot
72 * fail.
73 */
74static void *temp_kmalloc(size_t *size, void *backup, size_t backup_len)
75{
76	void *res;
77
78	if (WARN_ON(*size == 0))
79		return NULL;
80
81	if (*size < backup_len)
82		return backup;
83
84	if (!backup && iommufd_should_fail())
85		return NULL;
86
87	*size = min_t(size_t, *size, TEMP_MEMORY_LIMIT);
88	res = kmalloc(*size, GFP_KERNEL | __GFP_NOWARN | __GFP_NORETRY);
89	if (res)
90		return res;
91	*size = PAGE_SIZE;
92	if (backup_len) {
93		res = kmalloc(*size, GFP_KERNEL | __GFP_NOWARN | __GFP_NORETRY);
94		if (res)
95			return res;
96		*size = backup_len;
97		return backup;
98	}
99	return kmalloc(*size, GFP_KERNEL);
100}
101
102void interval_tree_double_span_iter_update(
103	struct interval_tree_double_span_iter *iter)
104{
105	unsigned long last_hole = ULONG_MAX;
106	unsigned int i;
107
108	for (i = 0; i != ARRAY_SIZE(iter->spans); i++) {
109		if (interval_tree_span_iter_done(&iter->spans[i])) {
110			iter->is_used = -1;
111			return;
112		}
113
114		if (iter->spans[i].is_hole) {
115			last_hole = min(last_hole, iter->spans[i].last_hole);
116			continue;
117		}
118
119		iter->is_used = i + 1;
120		iter->start_used = iter->spans[i].start_used;
121		iter->last_used = min(iter->spans[i].last_used, last_hole);
122		return;
123	}
124
125	iter->is_used = 0;
126	iter->start_hole = iter->spans[0].start_hole;
127	iter->last_hole =
128		min(iter->spans[0].last_hole, iter->spans[1].last_hole);
129}
130
131void interval_tree_double_span_iter_first(
132	struct interval_tree_double_span_iter *iter,
133	struct rb_root_cached *itree1, struct rb_root_cached *itree2,
134	unsigned long first_index, unsigned long last_index)
135{
136	unsigned int i;
137
138	iter->itrees[0] = itree1;
139	iter->itrees[1] = itree2;
140	for (i = 0; i != ARRAY_SIZE(iter->spans); i++)
141		interval_tree_span_iter_first(&iter->spans[i], iter->itrees[i],
142					      first_index, last_index);
143	interval_tree_double_span_iter_update(iter);
144}
145
146void interval_tree_double_span_iter_next(
147	struct interval_tree_double_span_iter *iter)
148{
149	unsigned int i;
150
151	if (iter->is_used == -1 ||
152	    iter->last_hole == iter->spans[0].last_index) {
153		iter->is_used = -1;
154		return;
155	}
156
157	for (i = 0; i != ARRAY_SIZE(iter->spans); i++)
158		interval_tree_span_iter_advance(
159			&iter->spans[i], iter->itrees[i], iter->last_hole + 1);
160	interval_tree_double_span_iter_update(iter);
161}
162
163static void iopt_pages_add_npinned(struct iopt_pages *pages, size_t npages)
164{
165	int rc;
166
167	rc = check_add_overflow(pages->npinned, npages, &pages->npinned);
168	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
169		WARN_ON(rc || pages->npinned > pages->npages);
170}
171
172static void iopt_pages_sub_npinned(struct iopt_pages *pages, size_t npages)
173{
174	int rc;
175
176	rc = check_sub_overflow(pages->npinned, npages, &pages->npinned);
177	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
178		WARN_ON(rc || pages->npinned > pages->npages);
179}
180
/*
 * Error unwind helper: unpin the pages in page_list that cover
 * [start_index, last_index] and drop them from the npinned accounting.
 */
static void iopt_pages_err_unpin(struct iopt_pages *pages,
				 unsigned long start_index,
				 unsigned long last_index,
				 struct page **page_list)
{
	unsigned long count = last_index + 1 - start_index;

	unpin_user_pages(page_list, count);
	iopt_pages_sub_npinned(pages, count);
}
191
192/*
193 * index is the number of PAGE_SIZE units from the start of the area's
194 * iopt_pages. If the iova is sub page-size then the area has an iova that
195 * covers a portion of the first and last pages in the range.
196 */
197static unsigned long iopt_area_index_to_iova(struct iopt_area *area,
198					     unsigned long index)
199{
200	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
201		WARN_ON(index < iopt_area_index(area) ||
202			index > iopt_area_last_index(area));
203	index -= iopt_area_index(area);
204	if (index == 0)
205		return iopt_area_iova(area);
206	return iopt_area_iova(area) - area->page_offset + index * PAGE_SIZE;
207}
208
209static unsigned long iopt_area_index_to_iova_last(struct iopt_area *area,
210						  unsigned long index)
211{
212	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
213		WARN_ON(index < iopt_area_index(area) ||
214			index > iopt_area_last_index(area));
215	if (index == iopt_area_last_index(area))
216		return iopt_area_last_iova(area);
217	return iopt_area_iova(area) - area->page_offset +
218	       (index - iopt_area_index(area) + 1) * PAGE_SIZE - 1;
219}
220
221static void iommu_unmap_nofail(struct iommu_domain *domain, unsigned long iova,
222			       size_t size)
223{
224	size_t ret;
225
226	ret = iommu_unmap(domain, iova, size);
227	/*
228	 * It is a logic error in this code or a driver bug if the IOMMU unmaps
229	 * something other than exactly as requested. This implies that the
230	 * iommu driver may not fail unmap for reasons beyond bad agruments.
231	 * Particularly, the iommu driver may not do a memory allocation on the
232	 * unmap path.
233	 */
234	WARN_ON(ret != size);
235}
236
/*
 * Unmap the iova that backs the area's pages [start_index, last_index] from
 * the domain.
 */
static void iopt_area_unmap_domain_range(struct iopt_area *area,
					 struct iommu_domain *domain,
					 unsigned long start_index,
					 unsigned long last_index)
{
	unsigned long first_iova = iopt_area_index_to_iova(area, start_index);
	unsigned long final_iova =
		iopt_area_index_to_iova_last(area, last_index);

	/* final_iova is inclusive, hence the + 1 */
	iommu_unmap_nofail(domain, first_iova, final_iova - first_iova + 1);
}
248
249static struct iopt_area *iopt_pages_find_domain_area(struct iopt_pages *pages,
250						     unsigned long index)
251{
252	struct interval_tree_node *node;
253
254	node = interval_tree_iter_first(&pages->domains_itree, index, index);
255	if (!node)
256		return NULL;
257	return container_of(node, struct iopt_area, pages_node);
258}
259
/*
 * A simple datastructure to hold a vector of PFNs, optimized for contiguous
 * PFNs. This is used as a temporary holding memory for shuttling pfns from one
 * place to another. Generally everything is made more efficient if operations
 * work on the largest possible grouping of pfns. eg fewer lock/unlock cycles,
 * better cache locality, etc
 *
 * Each entry i describes a run of npfns[i] physically contiguous pages that
 * starts at pfns[i].
 */
struct pfn_batch {
	/* First PFN of each run; also the start of the backing allocation */
	unsigned long *pfns;
	/* Page count of each run; placed after pfns[array_size] in the same memory */
	u32 *npfns;
	/* Number of entries pfns/npfns can hold */
	unsigned int array_size;
	/* Number of entries currently in use */
	unsigned int end;
	/* Total number of pages described by the entries in use */
	unsigned int total_pfns;
};
274
275static void batch_clear(struct pfn_batch *batch)
276{
277	batch->total_pfns = 0;
278	batch->end = 0;
279	batch->pfns[0] = 0;
280	batch->npfns[0] = 0;
281}
282
283/*
284 * Carry means we carry a portion of the final hugepage over to the front of the
285 * batch
286 */
287static void batch_clear_carry(struct pfn_batch *batch, unsigned int keep_pfns)
288{
289	if (!keep_pfns)
290		return batch_clear(batch);
291
292	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
293		WARN_ON(!batch->end ||
294			batch->npfns[batch->end - 1] < keep_pfns);
295
296	batch->total_pfns = keep_pfns;
297	batch->pfns[0] = batch->pfns[batch->end - 1] +
298			 (batch->npfns[batch->end - 1] - keep_pfns);
299	batch->npfns[0] = keep_pfns;
300	batch->end = 1;
301}
302
303static void batch_skip_carry(struct pfn_batch *batch, unsigned int skip_pfns)
304{
305	if (!batch->total_pfns)
306		return;
307	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
308		WARN_ON(batch->total_pfns != batch->npfns[0]);
309	skip_pfns = min(batch->total_pfns, skip_pfns);
310	batch->pfns[0] += skip_pfns;
311	batch->npfns[0] -= skip_pfns;
312	batch->total_pfns -= skip_pfns;
313}
314
315static int __batch_init(struct pfn_batch *batch, size_t max_pages, void *backup,
316			size_t backup_len)
317{
318	const size_t elmsz = sizeof(*batch->pfns) + sizeof(*batch->npfns);
319	size_t size = max_pages * elmsz;
320
321	batch->pfns = temp_kmalloc(&size, backup, backup_len);
322	if (!batch->pfns)
323		return -ENOMEM;
324	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) && WARN_ON(size < elmsz))
325		return -EINVAL;
326	batch->array_size = size / elmsz;
327	batch->npfns = (u32 *)(batch->pfns + batch->array_size);
328	batch_clear(batch);
329	return 0;
330}
331
332static int batch_init(struct pfn_batch *batch, size_t max_pages)
333{
334	return __batch_init(batch, max_pages, NULL, 0);
335}
336
337static void batch_init_backup(struct pfn_batch *batch, size_t max_pages,
338			      void *backup, size_t backup_len)
339{
340	__batch_init(batch, max_pages, backup, backup_len);
341}
342
343static void batch_destroy(struct pfn_batch *batch, void *backup)
344{
345	if (batch->pfns != backup)
346		kfree(batch->pfns);
347}
348
349/* true if the pfn was added, false otherwise */
350static bool batch_add_pfn(struct pfn_batch *batch, unsigned long pfn)
351{
352	const unsigned int MAX_NPFNS = type_max(typeof(*batch->npfns));
353
354	if (batch->end &&
355	    pfn == batch->pfns[batch->end - 1] + batch->npfns[batch->end - 1] &&
356	    batch->npfns[batch->end - 1] != MAX_NPFNS) {
357		batch->npfns[batch->end - 1]++;
358		batch->total_pfns++;
359		return true;
360	}
361	if (batch->end == batch->array_size)
362		return false;
363	batch->total_pfns++;
364	batch->pfns[batch->end] = pfn;
365	batch->npfns[batch->end] = 1;
366	batch->end++;
367	return true;
368}
369
370/*
371 * Fill the batch with pfns from the domain. When the batch is full, or it
372 * reaches last_index, the function will return. The caller should use
373 * batch->total_pfns to determine the starting point for the next iteration.
374 */
375static void batch_from_domain(struct pfn_batch *batch,
376			      struct iommu_domain *domain,
377			      struct iopt_area *area, unsigned long start_index,
378			      unsigned long last_index)
379{
380	unsigned int page_offset = 0;
381	unsigned long iova;
382	phys_addr_t phys;
383
384	iova = iopt_area_index_to_iova(area, start_index);
385	if (start_index == iopt_area_index(area))
386		page_offset = area->page_offset;
387	while (start_index <= last_index) {
388		/*
389		 * This is pretty slow, it would be nice to get the page size
390		 * back from the driver, or have the driver directly fill the
391		 * batch.
392		 */
393		phys = iommu_iova_to_phys(domain, iova) - page_offset;
394		if (!batch_add_pfn(batch, PHYS_PFN(phys)))
395			return;
396		iova += PAGE_SIZE - page_offset;
397		page_offset = 0;
398		start_index++;
399	}
400}
401
402static struct page **raw_pages_from_domain(struct iommu_domain *domain,
403					   struct iopt_area *area,
404					   unsigned long start_index,
405					   unsigned long last_index,
406					   struct page **out_pages)
407{
408	unsigned int page_offset = 0;
409	unsigned long iova;
410	phys_addr_t phys;
411
412	iova = iopt_area_index_to_iova(area, start_index);
413	if (start_index == iopt_area_index(area))
414		page_offset = area->page_offset;
415	while (start_index <= last_index) {
416		phys = iommu_iova_to_phys(domain, iova) - page_offset;
417		*(out_pages++) = pfn_to_page(PHYS_PFN(phys));
418		iova += PAGE_SIZE - page_offset;
419		page_offset = 0;
420		start_index++;
421	}
422	return out_pages;
423}
424
425/* Continues reading a domain until we reach a discontinuity in the pfns. */
426static void batch_from_domain_continue(struct pfn_batch *batch,
427				       struct iommu_domain *domain,
428				       struct iopt_area *area,
429				       unsigned long start_index,
430				       unsigned long last_index)
431{
432	unsigned int array_size = batch->array_size;
433
434	batch->array_size = batch->end;
435	batch_from_domain(batch, domain, area, start_index, last_index);
436	batch->array_size = array_size;
437}
438
439/*
440 * This is part of the VFIO compatibility support for VFIO_TYPE1_IOMMU. That
441 * mode permits splitting a mapped area up, and then one of the splits is
442 * unmapped. Doing this normally would cause us to violate our invariant of
443 * pairing map/unmap. Thus, to support old VFIO compatibility disable support
444 * for batching consecutive PFNs. All PFNs mapped into the iommu are done in
445 * PAGE_SIZE units, not larger or smaller.
446 */
447static int batch_iommu_map_small(struct iommu_domain *domain,
448				 unsigned long iova, phys_addr_t paddr,
449				 size_t size, int prot)
450{
451	unsigned long start_iova = iova;
452	int rc;
453
454	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
455		WARN_ON(paddr % PAGE_SIZE || iova % PAGE_SIZE ||
456			size % PAGE_SIZE);
457
458	while (size) {
459		rc = iommu_map(domain, iova, paddr, PAGE_SIZE, prot,
460			       GFP_KERNEL_ACCOUNT);
461		if (rc)
462			goto err_unmap;
463		iova += PAGE_SIZE;
464		paddr += PAGE_SIZE;
465		size -= PAGE_SIZE;
466	}
467	return 0;
468
469err_unmap:
470	if (start_iova != iova)
471		iommu_unmap_nofail(domain, start_iova, iova - start_iova);
472	return rc;
473}
474
475static int batch_to_domain(struct pfn_batch *batch, struct iommu_domain *domain,
476			   struct iopt_area *area, unsigned long start_index)
477{
478	bool disable_large_pages = area->iopt->disable_large_pages;
479	unsigned long last_iova = iopt_area_last_iova(area);
480	unsigned int page_offset = 0;
481	unsigned long start_iova;
482	unsigned long next_iova;
483	unsigned int cur = 0;
484	unsigned long iova;
485	int rc;
486
487	/* The first index might be a partial page */
488	if (start_index == iopt_area_index(area))
489		page_offset = area->page_offset;
490	next_iova = iova = start_iova =
491		iopt_area_index_to_iova(area, start_index);
492	while (cur < batch->end) {
493		next_iova = min(last_iova + 1,
494				next_iova + batch->npfns[cur] * PAGE_SIZE -
495					page_offset);
496		if (disable_large_pages)
497			rc = batch_iommu_map_small(
498				domain, iova,
499				PFN_PHYS(batch->pfns[cur]) + page_offset,
500				next_iova - iova, area->iommu_prot);
501		else
502			rc = iommu_map(domain, iova,
503				       PFN_PHYS(batch->pfns[cur]) + page_offset,
504				       next_iova - iova, area->iommu_prot,
505				       GFP_KERNEL_ACCOUNT);
506		if (rc)
507			goto err_unmap;
508		iova = next_iova;
509		page_offset = 0;
510		cur++;
511	}
512	return 0;
513err_unmap:
514	if (start_iova != iova)
515		iommu_unmap_nofail(domain, start_iova, iova - start_iova);
516	return rc;
517}
518
519static void batch_from_xarray(struct pfn_batch *batch, struct xarray *xa,
520			      unsigned long start_index,
521			      unsigned long last_index)
522{
523	XA_STATE(xas, xa, start_index);
524	void *entry;
525
526	rcu_read_lock();
527	while (true) {
528		entry = xas_next(&xas);
529		if (xas_retry(&xas, entry))
530			continue;
531		WARN_ON(!xa_is_value(entry));
532		if (!batch_add_pfn(batch, xa_to_value(entry)) ||
533		    start_index == last_index)
534			break;
535		start_index++;
536	}
537	rcu_read_unlock();
538}
539
540static void batch_from_xarray_clear(struct pfn_batch *batch, struct xarray *xa,
541				    unsigned long start_index,
542				    unsigned long last_index)
543{
544	XA_STATE(xas, xa, start_index);
545	void *entry;
546
547	xas_lock(&xas);
548	while (true) {
549		entry = xas_next(&xas);
550		if (xas_retry(&xas, entry))
551			continue;
552		WARN_ON(!xa_is_value(entry));
553		if (!batch_add_pfn(batch, xa_to_value(entry)))
554			break;
555		xas_store(&xas, NULL);
556		if (start_index == last_index)
557			break;
558		start_index++;
559	}
560	xas_unlock(&xas);
561}
562
563static void clear_xarray(struct xarray *xa, unsigned long start_index,
564			 unsigned long last_index)
565{
566	XA_STATE(xas, xa, start_index);
567	void *entry;
568
569	xas_lock(&xas);
570	xas_for_each(&xas, entry, last_index)
571		xas_store(&xas, NULL);
572	xas_unlock(&xas);
573}
574
/*
 * Store the pfn of every page in pages[] into the xarray as value entries at
 * the consecutive indexes [start_index, last_index].
 *
 * Returns 0 on success. On error the entries stored so far by this call are
 * erased again and the xa_state error is returned.
 */
static int pages_to_xarray(struct xarray *xa, unsigned long start_index,
			   unsigned long last_index, struct page **pages)
{
	struct page **end_pages = pages + (last_index - start_index) + 1;
	/* Midpoint of the page list, used as the fault injection point */
	struct page **half_pages = pages + (end_pages - pages) / 2;
	XA_STATE(xas, xa, start_index);

	/*
	 * Stores are done under the xarray lock. When a store reports an error
	 * the lock is dropped and xas_nomem() decides whether the loop is
	 * retried from the current position.
	 */
	do {
		void *old;

		xas_lock(&xas);
		while (pages != end_pages) {
			/* xarray does not participate in fault injection */
			if (pages == half_pages && iommufd_should_fail()) {
				xas_set_err(&xas, -EINVAL);
				xas_unlock(&xas);
				/* aka xas_destroy() */
				xas_nomem(&xas, GFP_KERNEL);
				goto err_clear;
			}

			old = xas_store(&xas, xa_mk_value(page_to_pfn(*pages)));
			if (xas_error(&xas))
				break;
			/* The slot is expected to have been empty */
			WARN_ON(old);
			pages++;
			xas_next(&xas);
		}
		xas_unlock(&xas);
	} while (xas_nomem(&xas, GFP_KERNEL));

err_clear:
	if (xas_error(&xas)) {
		/* Undo the stores below the index that failed, if any */
		if (xas.xa_index != start_index)
			clear_xarray(xa, start_index, xas.xa_index - 1);
		return xas_error(&xas);
	}
	return 0;
}
614
615static void batch_from_pages(struct pfn_batch *batch, struct page **pages,
616			     size_t npages)
617{
618	struct page **end = pages + npages;
619
620	for (; pages != end; pages++)
621		if (!batch_add_pfn(batch, page_to_pfn(*pages)))
622			break;
623}
624
625static void batch_unpin(struct pfn_batch *batch, struct iopt_pages *pages,
626			unsigned int first_page_off, size_t npages)
627{
628	unsigned int cur = 0;
629
630	while (first_page_off) {
631		if (batch->npfns[cur] > first_page_off)
632			break;
633		first_page_off -= batch->npfns[cur];
634		cur++;
635	}
636
637	while (npages) {
638		size_t to_unpin = min_t(size_t, npages,
639					batch->npfns[cur] - first_page_off);
640
641		unpin_user_page_range_dirty_lock(
642			pfn_to_page(batch->pfns[cur] + first_page_off),
643			to_unpin, pages->writable);
644		iopt_pages_sub_npinned(pages, to_unpin);
645		cur++;
646		first_page_off = 0;
647		npages -= to_unpin;
648	}
649}
650
651static void copy_data_page(struct page *page, void *data, unsigned long offset,
652			   size_t length, unsigned int flags)
653{
654	void *mem;
655
656	mem = kmap_local_page(page);
657	if (flags & IOMMUFD_ACCESS_RW_WRITE) {
658		memcpy(mem + offset, data, length);
659		set_page_dirty_lock(page);
660	} else {
661		memcpy(data, mem + offset, length);
662	}
663	kunmap_local(mem);
664}
665
666static unsigned long batch_rw(struct pfn_batch *batch, void *data,
667			      unsigned long offset, unsigned long length,
668			      unsigned int flags)
669{
670	unsigned long copied = 0;
671	unsigned int npage = 0;
672	unsigned int cur = 0;
673
674	while (cur < batch->end) {
675		unsigned long bytes = min(length, PAGE_SIZE - offset);
676
677		copy_data_page(pfn_to_page(batch->pfns[cur] + npage), data,
678			       offset, bytes, flags);
679		offset = 0;
680		length -= bytes;
681		data += bytes;
682		copied += bytes;
683		npage++;
684		if (npage == batch->npfns[cur]) {
685			npage = 0;
686			cur++;
687		}
688		if (!length)
689			break;
690	}
691	return copied;
692}
693
/* pfn_reader_user is just the pin_user_pages() path */
struct pfn_reader_user {
	/* Page array that pin_user_pages*() fills; NULL until first allocated */
	struct page **upages;
	/* Size of the upages allocation in bytes */
	size_t upages_len;
	/* Index of the first page currently held in upages */
	unsigned long upages_start;
	/* Index one past the last page currently held in upages */
	unsigned long upages_end;
	/* FOLL_* flags passed to pin_user_pages*() */
	unsigned int gup_flags;
	/*
	 * 1 means mmget() and mmap_read_lock(), 0 means only mmget(), -1 is
	 * neither
	 */
	int locked;
};
707
708static void pfn_reader_user_init(struct pfn_reader_user *user,
709				 struct iopt_pages *pages)
710{
711	user->upages = NULL;
712	user->upages_start = 0;
713	user->upages_end = 0;
714	user->locked = -1;
715
716	user->gup_flags = FOLL_LONGTERM;
717	if (pages->writable)
718		user->gup_flags |= FOLL_WRITE;
719}
720
721static void pfn_reader_user_destroy(struct pfn_reader_user *user,
722				    struct iopt_pages *pages)
723{
724	if (user->locked != -1) {
725		if (user->locked)
726			mmap_read_unlock(pages->source_mm);
727		if (pages->source_mm != current->mm)
728			mmput(pages->source_mm);
729		user->locked = -1;
730	}
731
732	kfree(user->upages);
733	user->upages = NULL;
734}
735
736static int pfn_reader_user_pin(struct pfn_reader_user *user,
737			       struct iopt_pages *pages,
738			       unsigned long start_index,
739			       unsigned long last_index)
740{
741	bool remote_mm = pages->source_mm != current->mm;
742	unsigned long npages;
743	uintptr_t uptr;
744	long rc;
745
746	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
747	    WARN_ON(last_index < start_index))
748		return -EINVAL;
749
750	if (!user->upages) {
751		/* All undone in pfn_reader_destroy() */
752		user->upages_len =
753			(last_index - start_index + 1) * sizeof(*user->upages);
754		user->upages = temp_kmalloc(&user->upages_len, NULL, 0);
755		if (!user->upages)
756			return -ENOMEM;
757	}
758
759	if (user->locked == -1) {
760		/*
761		 * The majority of usages will run the map task within the mm
762		 * providing the pages, so we can optimize into
763		 * get_user_pages_fast()
764		 */
765		if (remote_mm) {
766			if (!mmget_not_zero(pages->source_mm))
767				return -EFAULT;
768		}
769		user->locked = 0;
770	}
771
772	npages = min_t(unsigned long, last_index - start_index + 1,
773		       user->upages_len / sizeof(*user->upages));
774
775
776	if (iommufd_should_fail())
777		return -EFAULT;
778
779	uptr = (uintptr_t)(pages->uptr + start_index * PAGE_SIZE);
780	if (!remote_mm)
781		rc = pin_user_pages_fast(uptr, npages, user->gup_flags,
782					 user->upages);
783	else {
784		if (!user->locked) {
785			mmap_read_lock(pages->source_mm);
786			user->locked = 1;
787		}
788		rc = pin_user_pages_remote(pages->source_mm, uptr, npages,
789					   user->gup_flags, user->upages,
790					   &user->locked);
791	}
792	if (rc <= 0) {
793		if (WARN_ON(!rc))
794			return -EFAULT;
795		return rc;
796	}
797	iopt_pages_add_npinned(pages, rc);
798	user->upages_start = start_index;
799	user->upages_end = start_index + rc;
800	return 0;
801}
802
803/* This is the "modern" and faster accounting method used by io_uring */
804static int incr_user_locked_vm(struct iopt_pages *pages, unsigned long npages)
805{
806	unsigned long lock_limit;
807	unsigned long cur_pages;
808	unsigned long new_pages;
809
810	lock_limit = task_rlimit(pages->source_task, RLIMIT_MEMLOCK) >>
811		     PAGE_SHIFT;
812	do {
813		cur_pages = atomic_long_read(&pages->source_user->locked_vm);
814		new_pages = cur_pages + npages;
815		if (new_pages > lock_limit)
816			return -ENOMEM;
817	} while (atomic_long_cmpxchg(&pages->source_user->locked_vm, cur_pages,
818				     new_pages) != cur_pages);
819	return 0;
820}
821
822static void decr_user_locked_vm(struct iopt_pages *pages, unsigned long npages)
823{
824	if (WARN_ON(atomic_long_read(&pages->source_user->locked_vm) < npages))
825		return;
826	atomic_long_sub(npages, &pages->source_user->locked_vm);
827}
828
829/* This is the accounting method used for compatibility with VFIO */
830static int update_mm_locked_vm(struct iopt_pages *pages, unsigned long npages,
831			       bool inc, struct pfn_reader_user *user)
832{
833	bool do_put = false;
834	int rc;
835
836	if (user && user->locked) {
837		mmap_read_unlock(pages->source_mm);
838		user->locked = 0;
839		/* If we had the lock then we also have a get */
840	} else if ((!user || !user->upages) &&
841		   pages->source_mm != current->mm) {
842		if (!mmget_not_zero(pages->source_mm))
843			return -EINVAL;
844		do_put = true;
845	}
846
847	mmap_write_lock(pages->source_mm);
848	rc = __account_locked_vm(pages->source_mm, npages, inc,
849				 pages->source_task, false);
850	mmap_write_unlock(pages->source_mm);
851
852	if (do_put)
853		mmput(pages->source_mm);
854	return rc;
855}
856
857static int do_update_pinned(struct iopt_pages *pages, unsigned long npages,
858			    bool inc, struct pfn_reader_user *user)
859{
860	int rc = 0;
861
862	switch (pages->account_mode) {
863	case IOPT_PAGES_ACCOUNT_NONE:
864		break;
865	case IOPT_PAGES_ACCOUNT_USER:
866		if (inc)
867			rc = incr_user_locked_vm(pages, npages);
868		else
869			decr_user_locked_vm(pages, npages);
870		break;
871	case IOPT_PAGES_ACCOUNT_MM:
872		rc = update_mm_locked_vm(pages, npages, inc, user);
873		break;
874	}
875	if (rc)
876		return rc;
877
878	pages->last_npinned = pages->npinned;
879	if (inc)
880		atomic64_add(npages, &pages->source_mm->pinned_vm);
881	else
882		atomic64_sub(npages, &pages->source_mm->pinned_vm);
883	return 0;
884}
885
886static void update_unpinned(struct iopt_pages *pages)
887{
888	if (WARN_ON(pages->npinned > pages->last_npinned))
889		return;
890	if (pages->npinned == pages->last_npinned)
891		return;
892	do_update_pinned(pages, pages->last_npinned - pages->npinned, false,
893			 NULL);
894}
895
896/*
897 * Changes in the number of pages pinned is done after the pages have been read
898 * and processed. If the user lacked the limit then the error unwind will unpin
899 * everything that was just pinned. This is because it is expensive to calculate
900 * how many pages we have already pinned within a range to generate an accurate
901 * prediction in advance of doing the work to actually pin them.
902 */
903static int pfn_reader_user_update_pinned(struct pfn_reader_user *user,
904					 struct iopt_pages *pages)
905{
906	unsigned long npages;
907	bool inc;
908
909	lockdep_assert_held(&pages->mutex);
910
911	if (pages->npinned == pages->last_npinned)
912		return 0;
913
914	if (pages->npinned < pages->last_npinned) {
915		npages = pages->last_npinned - pages->npinned;
916		inc = false;
917	} else {
918		if (iommufd_should_fail())
919			return -ENOMEM;
920		npages = pages->npinned - pages->last_npinned;
921		inc = true;
922	}
923	return do_update_pinned(pages, npages, inc, user);
924}
925
/*
 * PFNs are stored in three places, in order of preference:
 * - The iopt_pages xarray. This is only populated if there is an
 *   iopt_pages_access
 * - The iommu_domain under an area
 * - The original PFN source, ie pages->source_mm
 *
 * This iterator reads the pfns optimizing to load according to the
 * above order.
 */
struct pfn_reader {
	/* The iopt_pages being read */
	struct iopt_pages *pages;
	/* Walks access_itree and domains_itree together to pick the storage */
	struct interval_tree_double_span_iter span;
	/* PFNs loaded by the most recent pfn_reader_next() */
	struct pfn_batch batch;
	/* Index of the first page held in batch */
	unsigned long batch_start_index;
	/* Index one past the last page held in batch */
	unsigned long batch_end_index;
	/* Last index (inclusive) this reader will produce */
	unsigned long last_index;

	/* State for the pin_user_pages() source */
	struct pfn_reader_user user;
};
946
947static int pfn_reader_update_pinned(struct pfn_reader *pfns)
948{
949	return pfn_reader_user_update_pinned(&pfns->user, pfns->pages);
950}
951
952/*
953 * The batch can contain a mixture of pages that are still in use and pages that
954 * need to be unpinned. Unpin only pages that are not held anywhere else.
955 */
956static void pfn_reader_unpin(struct pfn_reader *pfns)
957{
958	unsigned long last = pfns->batch_end_index - 1;
959	unsigned long start = pfns->batch_start_index;
960	struct interval_tree_double_span_iter span;
961	struct iopt_pages *pages = pfns->pages;
962
963	lockdep_assert_held(&pages->mutex);
964
965	interval_tree_for_each_double_span(&span, &pages->access_itree,
966					   &pages->domains_itree, start, last) {
967		if (span.is_used)
968			continue;
969
970		batch_unpin(&pfns->batch, pages, span.start_hole - start,
971			    span.last_hole - span.start_hole + 1);
972	}
973}
974
975/* Process a single span to load it from the proper storage */
976static int pfn_reader_fill_span(struct pfn_reader *pfns)
977{
978	struct interval_tree_double_span_iter *span = &pfns->span;
979	unsigned long start_index = pfns->batch_end_index;
980	struct iopt_area *area;
981	int rc;
982
983	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
984	    WARN_ON(span->last_used < start_index))
985		return -EINVAL;
986
987	if (span->is_used == 1) {
988		batch_from_xarray(&pfns->batch, &pfns->pages->pinned_pfns,
989				  start_index, span->last_used);
990		return 0;
991	}
992
993	if (span->is_used == 2) {
994		/*
995		 * Pull as many pages from the first domain we find in the
996		 * target span. If it is too small then we will be called again
997		 * and we'll find another area.
998		 */
999		area = iopt_pages_find_domain_area(pfns->pages, start_index);
1000		if (WARN_ON(!area))
1001			return -EINVAL;
1002
1003		/* The storage_domain cannot change without the pages mutex */
1004		batch_from_domain(
1005			&pfns->batch, area->storage_domain, area, start_index,
1006			min(iopt_area_last_index(area), span->last_used));
1007		return 0;
1008	}
1009
1010	if (start_index >= pfns->user.upages_end) {
1011		rc = pfn_reader_user_pin(&pfns->user, pfns->pages, start_index,
1012					 span->last_hole);
1013		if (rc)
1014			return rc;
1015	}
1016
1017	batch_from_pages(&pfns->batch,
1018			 pfns->user.upages +
1019				 (start_index - pfns->user.upages_start),
1020			 pfns->user.upages_end - start_index);
1021	return 0;
1022}
1023
1024static bool pfn_reader_done(struct pfn_reader *pfns)
1025{
1026	return pfns->batch_start_index == pfns->last_index + 1;
1027}
1028
1029static int pfn_reader_next(struct pfn_reader *pfns)
1030{
1031	int rc;
1032
1033	batch_clear(&pfns->batch);
1034	pfns->batch_start_index = pfns->batch_end_index;
1035
1036	while (pfns->batch_end_index != pfns->last_index + 1) {
1037		unsigned int npfns = pfns->batch.total_pfns;
1038
1039		if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
1040		    WARN_ON(interval_tree_double_span_iter_done(&pfns->span)))
1041			return -EINVAL;
1042
1043		rc = pfn_reader_fill_span(pfns);
1044		if (rc)
1045			return rc;
1046
1047		if (WARN_ON(!pfns->batch.total_pfns))
1048			return -EINVAL;
1049
1050		pfns->batch_end_index =
1051			pfns->batch_start_index + pfns->batch.total_pfns;
1052		if (pfns->batch_end_index == pfns->span.last_used + 1)
1053			interval_tree_double_span_iter_next(&pfns->span);
1054
1055		/* Batch is full */
1056		if (npfns == pfns->batch.total_pfns)
1057			return 0;
1058	}
1059	return 0;
1060}
1061
1062static int pfn_reader_init(struct pfn_reader *pfns, struct iopt_pages *pages,
1063			   unsigned long start_index, unsigned long last_index)
1064{
1065	int rc;
1066
1067	lockdep_assert_held(&pages->mutex);
1068
1069	pfns->pages = pages;
1070	pfns->batch_start_index = start_index;
1071	pfns->batch_end_index = start_index;
1072	pfns->last_index = last_index;
1073	pfn_reader_user_init(&pfns->user, pages);
1074	rc = batch_init(&pfns->batch, last_index - start_index + 1);
1075	if (rc)
1076		return rc;
1077	interval_tree_double_span_iter_first(&pfns->span, &pages->access_itree,
1078					     &pages->domains_itree, start_index,
1079					     last_index);
1080	return 0;
1081}
1082
1083/*
1084 * There are many assertions regarding the state of pages->npinned vs
1085 * pages->last_pinned, for instance something like unmapping a domain must only
1086 * decrement the npinned, and pfn_reader_destroy() must be called only after all
1087 * the pins are updated. This is fine for success flows, but error flows
1088 * sometimes need to release the pins held inside the pfn_reader before going on
1089 * to complete unmapping and releasing pins held in domains.
1090 */
1091static void pfn_reader_release_pins(struct pfn_reader *pfns)
1092{
1093	struct iopt_pages *pages = pfns->pages;
1094
1095	if (pfns->user.upages_end > pfns->batch_end_index) {
1096		size_t npages = pfns->user.upages_end - pfns->batch_end_index;
1097
1098		/* Any pages not transferred to the batch are just unpinned */
1099		unpin_user_pages(pfns->user.upages + (pfns->batch_end_index -
1100						      pfns->user.upages_start),
1101				 npages);
1102		iopt_pages_sub_npinned(pages, npages);
1103		pfns->user.upages_end = pfns->batch_end_index;
1104	}
1105	if (pfns->batch_start_index != pfns->batch_end_index) {
1106		pfn_reader_unpin(pfns);
1107		pfns->batch_start_index = pfns->batch_end_index;
1108	}
1109}
1110
1111static void pfn_reader_destroy(struct pfn_reader *pfns)
1112{
1113	struct iopt_pages *pages = pfns->pages;
1114
1115	pfn_reader_release_pins(pfns);
1116	pfn_reader_user_destroy(&pfns->user, pfns->pages);
1117	batch_destroy(&pfns->batch, NULL);
1118	WARN_ON(pages->last_npinned != pages->npinned);
1119}
1120
1121static int pfn_reader_first(struct pfn_reader *pfns, struct iopt_pages *pages,
1122			    unsigned long start_index, unsigned long last_index)
1123{
1124	int rc;
1125
1126	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
1127	    WARN_ON(last_index < start_index))
1128		return -EINVAL;
1129
1130	rc = pfn_reader_init(pfns, pages, start_index, last_index);
1131	if (rc)
1132		return rc;
1133	rc = pfn_reader_next(pfns);
1134	if (rc) {
1135		pfn_reader_destroy(pfns);
1136		return rc;
1137	}
1138	return 0;
1139}
1140
1141struct iopt_pages *iopt_alloc_pages(void __user *uptr, unsigned long length,
1142				    bool writable)
1143{
1144	struct iopt_pages *pages;
1145	unsigned long end;
1146
1147	/*
1148	 * The iommu API uses size_t as the length, and protect the DIV_ROUND_UP
1149	 * below from overflow
1150	 */
1151	if (length > SIZE_MAX - PAGE_SIZE || length == 0)
1152		return ERR_PTR(-EINVAL);
1153
1154	if (check_add_overflow((unsigned long)uptr, length, &end))
1155		return ERR_PTR(-EOVERFLOW);
1156
1157	pages = kzalloc(sizeof(*pages), GFP_KERNEL_ACCOUNT);
1158	if (!pages)
1159		return ERR_PTR(-ENOMEM);
1160
1161	kref_init(&pages->kref);
1162	xa_init_flags(&pages->pinned_pfns, XA_FLAGS_ACCOUNT);
1163	mutex_init(&pages->mutex);
1164	pages->source_mm = current->mm;
1165	mmgrab(pages->source_mm);
1166	pages->uptr = (void __user *)ALIGN_DOWN((uintptr_t)uptr, PAGE_SIZE);
1167	pages->npages = DIV_ROUND_UP(length + (uptr - pages->uptr), PAGE_SIZE);
1168	pages->access_itree = RB_ROOT_CACHED;
1169	pages->domains_itree = RB_ROOT_CACHED;
1170	pages->writable = writable;
1171	if (capable(CAP_IPC_LOCK))
1172		pages->account_mode = IOPT_PAGES_ACCOUNT_NONE;
1173	else
1174		pages->account_mode = IOPT_PAGES_ACCOUNT_USER;
1175	pages->source_task = current->group_leader;
1176	get_task_struct(current->group_leader);
1177	pages->source_user = get_uid(current_user());
1178	return pages;
1179}
1180
1181void iopt_release_pages(struct kref *kref)
1182{
1183	struct iopt_pages *pages = container_of(kref, struct iopt_pages, kref);
1184
1185	WARN_ON(!RB_EMPTY_ROOT(&pages->access_itree.rb_root));
1186	WARN_ON(!RB_EMPTY_ROOT(&pages->domains_itree.rb_root));
1187	WARN_ON(pages->npinned);
1188	WARN_ON(!xa_empty(&pages->pinned_pfns));
1189	mmdrop(pages->source_mm);
1190	mutex_destroy(&pages->mutex);
1191	put_task_struct(pages->source_task);
1192	free_uid(pages->source_user);
1193	kfree(pages);
1194}
1195
1196static void
1197iopt_area_unpin_domain(struct pfn_batch *batch, struct iopt_area *area,
1198		       struct iopt_pages *pages, struct iommu_domain *domain,
1199		       unsigned long start_index, unsigned long last_index,
1200		       unsigned long *unmapped_end_index,
1201		       unsigned long real_last_index)
1202{
1203	while (start_index <= last_index) {
1204		unsigned long batch_last_index;
1205
1206		if (*unmapped_end_index <= last_index) {
1207			unsigned long start =
1208				max(start_index, *unmapped_end_index);
1209
1210			if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
1211			    batch->total_pfns)
1212				WARN_ON(*unmapped_end_index -
1213						batch->total_pfns !=
1214					start_index);
1215			batch_from_domain(batch, domain, area, start,
1216					  last_index);
1217			batch_last_index = start_index + batch->total_pfns - 1;
1218		} else {
1219			batch_last_index = last_index;
1220		}
1221
1222		if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
1223			WARN_ON(batch_last_index > real_last_index);
1224
1225		/*
1226		 * unmaps must always 'cut' at a place where the pfns are not
1227		 * contiguous to pair with the maps that always install
1228		 * contiguous pages. Thus, if we have to stop unpinning in the
1229		 * middle of the domains we need to keep reading pfns until we
1230		 * find a cut point to do the unmap. The pfns we read are
1231		 * carried over and either skipped or integrated into the next
1232		 * batch.
1233		 */
1234		if (batch_last_index == last_index &&
1235		    last_index != real_last_index)
1236			batch_from_domain_continue(batch, domain, area,
1237						   last_index + 1,
1238						   real_last_index);
1239
1240		if (*unmapped_end_index <= batch_last_index) {
1241			iopt_area_unmap_domain_range(
1242				area, domain, *unmapped_end_index,
1243				start_index + batch->total_pfns - 1);
1244			*unmapped_end_index = start_index + batch->total_pfns;
1245		}
1246
1247		/* unpin must follow unmap */
1248		batch_unpin(batch, pages, 0,
1249			    batch_last_index - start_index + 1);
1250		start_index = batch_last_index + 1;
1251
1252		batch_clear_carry(batch,
1253				  *unmapped_end_index - batch_last_index - 1);
1254	}
1255}
1256
1257static void __iopt_area_unfill_domain(struct iopt_area *area,
1258				      struct iopt_pages *pages,
1259				      struct iommu_domain *domain,
1260				      unsigned long last_index)
1261{
1262	struct interval_tree_double_span_iter span;
1263	unsigned long start_index = iopt_area_index(area);
1264	unsigned long unmapped_end_index = start_index;
1265	u64 backup[BATCH_BACKUP_SIZE];
1266	struct pfn_batch batch;
1267
1268	lockdep_assert_held(&pages->mutex);
1269
1270	/*
1271	 * For security we must not unpin something that is still DMA mapped,
1272	 * so this must unmap any IOVA before we go ahead and unpin the pages.
1273	 * This creates a complexity where we need to skip over unpinning pages
1274	 * held in the xarray, but continue to unmap from the domain.
1275	 *
1276	 * The domain unmap cannot stop in the middle of a contiguous range of
1277	 * PFNs. To solve this problem the unpinning step will read ahead to the
1278	 * end of any contiguous span, unmap that whole span, and then only
1279	 * unpin the leading part that does not have any accesses. The residual
1280	 * PFNs that were unmapped but not unpinned are called a "carry" in the
1281	 * batch as they are moved to the front of the PFN list and continue on
1282	 * to the next iteration(s).
1283	 */
1284	batch_init_backup(&batch, last_index + 1, backup, sizeof(backup));
1285	interval_tree_for_each_double_span(&span, &pages->domains_itree,
1286					   &pages->access_itree, start_index,
1287					   last_index) {
1288		if (span.is_used) {
1289			batch_skip_carry(&batch,
1290					 span.last_used - span.start_used + 1);
1291			continue;
1292		}
1293		iopt_area_unpin_domain(&batch, area, pages, domain,
1294				       span.start_hole, span.last_hole,
1295				       &unmapped_end_index, last_index);
1296	}
1297	/*
1298	 * If the range ends in a access then we do the residual unmap without
1299	 * any unpins.
1300	 */
1301	if (unmapped_end_index != last_index + 1)
1302		iopt_area_unmap_domain_range(area, domain, unmapped_end_index,
1303					     last_index);
1304	WARN_ON(batch.total_pfns);
1305	batch_destroy(&batch, backup);
1306	update_unpinned(pages);
1307}
1308
/*
 * Undo a partially completed fill: unfill the area's indexes below @end_index
 * (exclusive). Nothing is done when @end_index is the area's first index.
 */
static void iopt_area_unfill_partial_domain(struct iopt_area *area,
					    struct iopt_pages *pages,
					    struct iommu_domain *domain,
					    unsigned long end_index)
{
	/* Nothing was filled */
	if (end_index == iopt_area_index(area))
		return;

	__iopt_area_unfill_domain(area, pages, domain, end_index - 1);
}
1317
/**
 * iopt_area_unmap_domain() - Unmap without unpinning PFNs in a domain
 * @area: The IOVA range to unmap
 * @domain: The domain to unmap
 *
 * Removes the area's IOVA range, iopt_area_length() bytes starting at
 * iopt_area_iova(), from @domain and leaves all pins untouched. The caller
 * must know that unpinning is not required, usually because there are other
 * domains in the iopt.
 */
void iopt_area_unmap_domain(struct iopt_area *area, struct iommu_domain *domain)
{
	iommu_unmap_nofail(domain,
			   iopt_area_iova(area), iopt_area_length(area));
}
1331
/**
 * iopt_area_unfill_domain() - Unmap and unpin PFNs in a domain
 * @area: IOVA area to use
 * @pages: page supplier for the area (area->pages is NULL)
 * @domain: Domain to unmap from
 *
 * Covers the area's entire index range. The domain should already have been
 * removed from the domains_itree. The domain is always unmapped, but PFNs that
 * still have accesses may stay pinned.
 */
void iopt_area_unfill_domain(struct iopt_area *area, struct iopt_pages *pages,
			     struct iommu_domain *domain)
{
	unsigned long last_index = iopt_area_last_index(area);

	__iopt_area_unfill_domain(area, pages, domain, last_index);
}
1348
1349/**
1350 * iopt_area_fill_domain() - Map PFNs from the area into a domain
1351 * @area: IOVA area to use
1352 * @domain: Domain to load PFNs into
1353 *
1354 * Read the pfns from the area's underlying iopt_pages and map them into the
1355 * given domain. Called when attaching a new domain to an io_pagetable.
1356 */
1357int iopt_area_fill_domain(struct iopt_area *area, struct iommu_domain *domain)
1358{
1359	unsigned long done_end_index;
1360	struct pfn_reader pfns;
1361	int rc;
1362
1363	lockdep_assert_held(&area->pages->mutex);
1364
1365	rc = pfn_reader_first(&pfns, area->pages, iopt_area_index(area),
1366			      iopt_area_last_index(area));
1367	if (rc)
1368		return rc;
1369
1370	while (!pfn_reader_done(&pfns)) {
1371		done_end_index = pfns.batch_start_index;
1372		rc = batch_to_domain(&pfns.batch, domain, area,
1373				     pfns.batch_start_index);
1374		if (rc)
1375			goto out_unmap;
1376		done_end_index = pfns.batch_end_index;
1377
1378		rc = pfn_reader_next(&pfns);
1379		if (rc)
1380			goto out_unmap;
1381	}
1382
1383	rc = pfn_reader_update_pinned(&pfns);
1384	if (rc)
1385		goto out_unmap;
1386	goto out_destroy;
1387
1388out_unmap:
1389	pfn_reader_release_pins(&pfns);
1390	iopt_area_unfill_partial_domain(area, area->pages, domain,
1391					done_end_index);
1392out_destroy:
1393	pfn_reader_destroy(&pfns);
1394	return rc;
1395}
1396
1397/**
1398 * iopt_area_fill_domains() - Install PFNs into the area's domains
1399 * @area: The area to act on
1400 * @pages: The pages associated with the area (area->pages is NULL)
1401 *
1402 * Called during area creation. The area is freshly created and not inserted in
1403 * the domains_itree yet. PFNs are read and loaded into every domain held in the
1404 * area's io_pagetable and the area is installed in the domains_itree.
1405 *
1406 * On failure all domains are left unchanged.
1407 */
1408int iopt_area_fill_domains(struct iopt_area *area, struct iopt_pages *pages)
1409{
1410	unsigned long done_first_end_index;
1411	unsigned long done_all_end_index;
1412	struct iommu_domain *domain;
1413	unsigned long unmap_index;
1414	struct pfn_reader pfns;
1415	unsigned long index;
1416	int rc;
1417
1418	lockdep_assert_held(&area->iopt->domains_rwsem);
1419
1420	if (xa_empty(&area->iopt->domains))
1421		return 0;
1422
1423	mutex_lock(&pages->mutex);
1424	rc = pfn_reader_first(&pfns, pages, iopt_area_index(area),
1425			      iopt_area_last_index(area));
1426	if (rc)
1427		goto out_unlock;
1428
1429	while (!pfn_reader_done(&pfns)) {
1430		done_first_end_index = pfns.batch_end_index;
1431		done_all_end_index = pfns.batch_start_index;
1432		xa_for_each(&area->iopt->domains, index, domain) {
1433			rc = batch_to_domain(&pfns.batch, domain, area,
1434					     pfns.batch_start_index);
1435			if (rc)
1436				goto out_unmap;
1437		}
1438		done_all_end_index = done_first_end_index;
1439
1440		rc = pfn_reader_next(&pfns);
1441		if (rc)
1442			goto out_unmap;
1443	}
1444	rc = pfn_reader_update_pinned(&pfns);
1445	if (rc)
1446		goto out_unmap;
1447
1448	area->storage_domain = xa_load(&area->iopt->domains, 0);
1449	interval_tree_insert(&area->pages_node, &pages->domains_itree);
1450	goto out_destroy;
1451
1452out_unmap:
1453	pfn_reader_release_pins(&pfns);
1454	xa_for_each(&area->iopt->domains, unmap_index, domain) {
1455		unsigned long end_index;
1456
1457		if (unmap_index < index)
1458			end_index = done_first_end_index;
1459		else
1460			end_index = done_all_end_index;
1461
1462		/*
1463		 * The area is not yet part of the domains_itree so we have to
1464		 * manage the unpinning specially. The last domain does the
1465		 * unpin, every other domain is just unmapped.
1466		 */
1467		if (unmap_index != area->iopt->next_domain_id - 1) {
1468			if (end_index != iopt_area_index(area))
1469				iopt_area_unmap_domain_range(
1470					area, domain, iopt_area_index(area),
1471					end_index - 1);
1472		} else {
1473			iopt_area_unfill_partial_domain(area, pages, domain,
1474							end_index);
1475		}
1476	}
1477out_destroy:
1478	pfn_reader_destroy(&pfns);
1479out_unlock:
1480	mutex_unlock(&pages->mutex);
1481	return rc;
1482}
1483
1484/**
1485 * iopt_area_unfill_domains() - unmap PFNs from the area's domains
1486 * @area: The area to act on
1487 * @pages: The pages associated with the area (area->pages is NULL)
1488 *
1489 * Called during area destruction. This unmaps the iova's covered by all the
1490 * area's domains and releases the PFNs.
1491 */
1492void iopt_area_unfill_domains(struct iopt_area *area, struct iopt_pages *pages)
1493{
1494	struct io_pagetable *iopt = area->iopt;
1495	struct iommu_domain *domain;
1496	unsigned long index;
1497
1498	lockdep_assert_held(&iopt->domains_rwsem);
1499
1500	mutex_lock(&pages->mutex);
1501	if (!area->storage_domain)
1502		goto out_unlock;
1503
1504	xa_for_each(&iopt->domains, index, domain)
1505		if (domain != area->storage_domain)
1506			iopt_area_unmap_domain_range(
1507				area, domain, iopt_area_index(area),
1508				iopt_area_last_index(area));
1509
1510	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
1511		WARN_ON(RB_EMPTY_NODE(&area->pages_node.rb));
1512	interval_tree_remove(&area->pages_node, &pages->domains_itree);
1513	iopt_area_unfill_domain(area, pages, area->storage_domain);
1514	area->storage_domain = NULL;
1515out_unlock:
1516	mutex_unlock(&pages->mutex);
1517}
1518
1519static void iopt_pages_unpin_xarray(struct pfn_batch *batch,
1520				    struct iopt_pages *pages,
1521				    unsigned long start_index,
1522				    unsigned long end_index)
1523{
1524	while (start_index <= end_index) {
1525		batch_from_xarray_clear(batch, &pages->pinned_pfns, start_index,
1526					end_index);
1527		batch_unpin(batch, pages, 0, batch->total_pfns);
1528		start_index += batch->total_pfns;
1529		batch_clear(batch);
1530	}
1531}
1532
1533/**
1534 * iopt_pages_unfill_xarray() - Update the xarry after removing an access
1535 * @pages: The pages to act on
1536 * @start_index: Starting PFN index
1537 * @last_index: Last PFN index
1538 *
1539 * Called when an iopt_pages_access is removed, removes pages from the itree.
1540 * The access should already be removed from the access_itree.
1541 */
1542void iopt_pages_unfill_xarray(struct iopt_pages *pages,
1543			      unsigned long start_index,
1544			      unsigned long last_index)
1545{
1546	struct interval_tree_double_span_iter span;
1547	u64 backup[BATCH_BACKUP_SIZE];
1548	struct pfn_batch batch;
1549	bool batch_inited = false;
1550
1551	lockdep_assert_held(&pages->mutex);
1552
1553	interval_tree_for_each_double_span(&span, &pages->access_itree,
1554					   &pages->domains_itree, start_index,
1555					   last_index) {
1556		if (!span.is_used) {
1557			if (!batch_inited) {
1558				batch_init_backup(&batch,
1559						  last_index - start_index + 1,
1560						  backup, sizeof(backup));
1561				batch_inited = true;
1562			}
1563			iopt_pages_unpin_xarray(&batch, pages, span.start_hole,
1564						span.last_hole);
1565		} else if (span.is_used == 2) {
1566			/* Covered by a domain */
1567			clear_xarray(&pages->pinned_pfns, span.start_used,
1568				     span.last_used);
1569		}
1570		/* Otherwise covered by an existing access */
1571	}
1572	if (batch_inited)
1573		batch_destroy(&batch, backup);
1574	update_unpinned(pages);
1575}
1576
1577/**
1578 * iopt_pages_fill_from_xarray() - Fast path for reading PFNs
1579 * @pages: The pages to act on
1580 * @start_index: The first page index in the range
1581 * @last_index: The last page index in the range
1582 * @out_pages: The output array to return the pages
1583 *
1584 * This can be called if the caller is holding a refcount on an
1585 * iopt_pages_access that is known to have already been filled. It quickly reads
1586 * the pages directly from the xarray.
1587 *
1588 * This is part of the SW iommu interface to read pages for in-kernel use.
1589 */
1590void iopt_pages_fill_from_xarray(struct iopt_pages *pages,
1591				 unsigned long start_index,
1592				 unsigned long last_index,
1593				 struct page **out_pages)
1594{
1595	XA_STATE(xas, &pages->pinned_pfns, start_index);
1596	void *entry;
1597
1598	rcu_read_lock();
1599	while (start_index <= last_index) {
1600		entry = xas_next(&xas);
1601		if (xas_retry(&xas, entry))
1602			continue;
1603		WARN_ON(!xa_is_value(entry));
1604		*(out_pages++) = pfn_to_page(xa_to_value(entry));
1605		start_index++;
1606	}
1607	rcu_read_unlock();
1608}
1609
1610static int iopt_pages_fill_from_domain(struct iopt_pages *pages,
1611				       unsigned long start_index,
1612				       unsigned long last_index,
1613				       struct page **out_pages)
1614{
1615	while (start_index != last_index + 1) {
1616		unsigned long domain_last;
1617		struct iopt_area *area;
1618
1619		area = iopt_pages_find_domain_area(pages, start_index);
1620		if (WARN_ON(!area))
1621			return -EINVAL;
1622
1623		domain_last = min(iopt_area_last_index(area), last_index);
1624		out_pages = raw_pages_from_domain(area->storage_domain, area,
1625						  start_index, domain_last,
1626						  out_pages);
1627		start_index = domain_last + 1;
1628	}
1629	return 0;
1630}
1631
1632static int iopt_pages_fill_from_mm(struct iopt_pages *pages,
1633				   struct pfn_reader_user *user,
1634				   unsigned long start_index,
1635				   unsigned long last_index,
1636				   struct page **out_pages)
1637{
1638	unsigned long cur_index = start_index;
1639	int rc;
1640
1641	while (cur_index != last_index + 1) {
1642		user->upages = out_pages + (cur_index - start_index);
1643		rc = pfn_reader_user_pin(user, pages, cur_index, last_index);
1644		if (rc)
1645			goto out_unpin;
1646		cur_index = user->upages_end;
1647	}
1648	return 0;
1649
1650out_unpin:
1651	if (start_index != cur_index)
1652		iopt_pages_err_unpin(pages, start_index, cur_index - 1,
1653				     out_pages);
1654	return rc;
1655}
1656
1657/**
1658 * iopt_pages_fill_xarray() - Read PFNs
1659 * @pages: The pages to act on
1660 * @start_index: The first page index in the range
1661 * @last_index: The last page index in the range
1662 * @out_pages: The output array to return the pages, may be NULL
1663 *
1664 * This populates the xarray and returns the pages in out_pages. As the slow
1665 * path this is able to copy pages from other storage tiers into the xarray.
1666 *
1667 * On failure the xarray is left unchanged.
1668 *
1669 * This is part of the SW iommu interface to read pages for in-kernel use.
1670 */
1671int iopt_pages_fill_xarray(struct iopt_pages *pages, unsigned long start_index,
1672			   unsigned long last_index, struct page **out_pages)
1673{
1674	struct interval_tree_double_span_iter span;
1675	unsigned long xa_end = start_index;
1676	struct pfn_reader_user user;
1677	int rc;
1678
1679	lockdep_assert_held(&pages->mutex);
1680
1681	pfn_reader_user_init(&user, pages);
1682	user.upages_len = (last_index - start_index + 1) * sizeof(*out_pages);
1683	interval_tree_for_each_double_span(&span, &pages->access_itree,
1684					   &pages->domains_itree, start_index,
1685					   last_index) {
1686		struct page **cur_pages;
1687
1688		if (span.is_used == 1) {
1689			cur_pages = out_pages + (span.start_used - start_index);
1690			iopt_pages_fill_from_xarray(pages, span.start_used,
1691						    span.last_used, cur_pages);
1692			continue;
1693		}
1694
1695		if (span.is_used == 2) {
1696			cur_pages = out_pages + (span.start_used - start_index);
1697			iopt_pages_fill_from_domain(pages, span.start_used,
1698						    span.last_used, cur_pages);
1699			rc = pages_to_xarray(&pages->pinned_pfns,
1700					     span.start_used, span.last_used,
1701					     cur_pages);
1702			if (rc)
1703				goto out_clean_xa;
1704			xa_end = span.last_used + 1;
1705			continue;
1706		}
1707
1708		/* hole */
1709		cur_pages = out_pages + (span.start_hole - start_index);
1710		rc = iopt_pages_fill_from_mm(pages, &user, span.start_hole,
1711					     span.last_hole, cur_pages);
1712		if (rc)
1713			goto out_clean_xa;
1714		rc = pages_to_xarray(&pages->pinned_pfns, span.start_hole,
1715				     span.last_hole, cur_pages);
1716		if (rc) {
1717			iopt_pages_err_unpin(pages, span.start_hole,
1718					     span.last_hole, cur_pages);
1719			goto out_clean_xa;
1720		}
1721		xa_end = span.last_hole + 1;
1722	}
1723	rc = pfn_reader_user_update_pinned(&user, pages);
1724	if (rc)
1725		goto out_clean_xa;
1726	user.upages = NULL;
1727	pfn_reader_user_destroy(&user, pages);
1728	return 0;
1729
1730out_clean_xa:
1731	if (start_index != xa_end)
1732		iopt_pages_unfill_xarray(pages, start_index, xa_end - 1);
1733	user.upages = NULL;
1734	pfn_reader_user_destroy(&user, pages);
1735	return rc;
1736}
1737
1738/*
1739 * This uses the pfn_reader instead of taking a shortcut by using the mm. It can
1740 * do every scenario and is fully consistent with what an iommu_domain would
1741 * see.
1742 */
1743static int iopt_pages_rw_slow(struct iopt_pages *pages,
1744			      unsigned long start_index,
1745			      unsigned long last_index, unsigned long offset,
1746			      void *data, unsigned long length,
1747			      unsigned int flags)
1748{
1749	struct pfn_reader pfns;
1750	int rc;
1751
1752	mutex_lock(&pages->mutex);
1753
1754	rc = pfn_reader_first(&pfns, pages, start_index, last_index);
1755	if (rc)
1756		goto out_unlock;
1757
1758	while (!pfn_reader_done(&pfns)) {
1759		unsigned long done;
1760
1761		done = batch_rw(&pfns.batch, data, offset, length, flags);
1762		data += done;
1763		length -= done;
1764		offset = 0;
1765		pfn_reader_unpin(&pfns);
1766
1767		rc = pfn_reader_next(&pfns);
1768		if (rc)
1769			goto out_destroy;
1770	}
1771	if (WARN_ON(length != 0))
1772		rc = -EINVAL;
1773out_destroy:
1774	pfn_reader_destroy(&pfns);
1775out_unlock:
1776	mutex_unlock(&pages->mutex);
1777	return rc;
1778}
1779
1780/*
1781 * A medium speed path that still allows DMA inconsistencies, but doesn't do any
1782 * memory allocations or interval tree searches.
1783 */
1784static int iopt_pages_rw_page(struct iopt_pages *pages, unsigned long index,
1785			      unsigned long offset, void *data,
1786			      unsigned long length, unsigned int flags)
1787{
1788	struct page *page = NULL;
1789	int rc;
1790
1791	if (!mmget_not_zero(pages->source_mm))
1792		return iopt_pages_rw_slow(pages, index, index, offset, data,
1793					  length, flags);
1794
1795	if (iommufd_should_fail()) {
1796		rc = -EINVAL;
1797		goto out_mmput;
1798	}
1799
1800	mmap_read_lock(pages->source_mm);
1801	rc = pin_user_pages_remote(
1802		pages->source_mm, (uintptr_t)(pages->uptr + index * PAGE_SIZE),
1803		1, (flags & IOMMUFD_ACCESS_RW_WRITE) ? FOLL_WRITE : 0, &page,
1804		NULL);
1805	mmap_read_unlock(pages->source_mm);
1806	if (rc != 1) {
1807		if (WARN_ON(rc >= 0))
1808			rc = -EINVAL;
1809		goto out_mmput;
1810	}
1811	copy_data_page(page, data, offset, length, flags);
1812	unpin_user_page(page);
1813	rc = 0;
1814
1815out_mmput:
1816	mmput(pages->source_mm);
1817	return rc;
1818}
1819
1820/**
1821 * iopt_pages_rw_access - Copy to/from a linear slice of the pages
1822 * @pages: pages to act on
1823 * @start_byte: First byte of pages to copy to/from
1824 * @data: Kernel buffer to get/put the data
1825 * @length: Number of bytes to copy
1826 * @flags: IOMMUFD_ACCESS_RW_* flags
1827 *
1828 * This will find each page in the range, kmap it and then memcpy to/from
1829 * the given kernel buffer.
1830 */
1831int iopt_pages_rw_access(struct iopt_pages *pages, unsigned long start_byte,
1832			 void *data, unsigned long length, unsigned int flags)
1833{
1834	unsigned long start_index = start_byte / PAGE_SIZE;
1835	unsigned long last_index = (start_byte + length - 1) / PAGE_SIZE;
1836	bool change_mm = current->mm != pages->source_mm;
1837	int rc = 0;
1838
1839	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
1840	    (flags & __IOMMUFD_ACCESS_RW_SLOW_PATH))
1841		change_mm = true;
1842
1843	if ((flags & IOMMUFD_ACCESS_RW_WRITE) && !pages->writable)
1844		return -EPERM;
1845
1846	if (!(flags & IOMMUFD_ACCESS_RW_KTHREAD) && change_mm) {
1847		if (start_index == last_index)
1848			return iopt_pages_rw_page(pages, start_index,
1849						  start_byte % PAGE_SIZE, data,
1850						  length, flags);
1851		return iopt_pages_rw_slow(pages, start_index, last_index,
1852					  start_byte % PAGE_SIZE, data, length,
1853					  flags);
1854	}
1855
1856	/*
1857	 * Try to copy using copy_to_user(). We do this as a fast path and
1858	 * ignore any pinning inconsistencies, unlike a real DMA path.
1859	 */
1860	if (change_mm) {
1861		if (!mmget_not_zero(pages->source_mm))
1862			return iopt_pages_rw_slow(pages, start_index,
1863						  last_index,
1864						  start_byte % PAGE_SIZE, data,
1865						  length, flags);
1866		kthread_use_mm(pages->source_mm);
1867	}
1868
1869	if (flags & IOMMUFD_ACCESS_RW_WRITE) {
1870		if (copy_to_user(pages->uptr + start_byte, data, length))
1871			rc = -EFAULT;
1872	} else {
1873		if (copy_from_user(data, pages->uptr + start_byte, length))
1874			rc = -EFAULT;
1875	}
1876
1877	if (change_mm) {
1878		kthread_unuse_mm(pages->source_mm);
1879		mmput(pages->source_mm);
1880	}
1881
1882	return rc;
1883}
1884
1885static struct iopt_pages_access *
1886iopt_pages_get_exact_access(struct iopt_pages *pages, unsigned long index,
1887			    unsigned long last)
1888{
1889	struct interval_tree_node *node;
1890
1891	lockdep_assert_held(&pages->mutex);
1892
1893	/* There can be overlapping ranges in this interval tree */
1894	for (node = interval_tree_iter_first(&pages->access_itree, index, last);
1895	     node; node = interval_tree_iter_next(node, index, last))
1896		if (node->start == index && node->last == last)
1897			return container_of(node, struct iopt_pages_access,
1898					    node);
1899	return NULL;
1900}
1901
1902/**
1903 * iopt_area_add_access() - Record an in-knerel access for PFNs
1904 * @area: The source of PFNs
1905 * @start_index: First page index
1906 * @last_index: Inclusive last page index
1907 * @out_pages: Output list of struct page's representing the PFNs
1908 * @flags: IOMMUFD_ACCESS_RW_* flags
1909 *
1910 * Record that an in-kernel access will be accessing the pages, ensure they are
1911 * pinned, and return the PFNs as a simple list of 'struct page *'.
1912 *
1913 * This should be undone through a matching call to iopt_area_remove_access()
1914 */
1915int iopt_area_add_access(struct iopt_area *area, unsigned long start_index,
1916			  unsigned long last_index, struct page **out_pages,
1917			  unsigned int flags)
1918{
1919	struct iopt_pages *pages = area->pages;
1920	struct iopt_pages_access *access;
1921	int rc;
1922
1923	if ((flags & IOMMUFD_ACCESS_RW_WRITE) && !pages->writable)
1924		return -EPERM;
1925
1926	mutex_lock(&pages->mutex);
1927	access = iopt_pages_get_exact_access(pages, start_index, last_index);
1928	if (access) {
1929		area->num_accesses++;
1930		access->users++;
1931		iopt_pages_fill_from_xarray(pages, start_index, last_index,
1932					    out_pages);
1933		mutex_unlock(&pages->mutex);
1934		return 0;
1935	}
1936
1937	access = kzalloc(sizeof(*access), GFP_KERNEL_ACCOUNT);
1938	if (!access) {
1939		rc = -ENOMEM;
1940		goto err_unlock;
1941	}
1942
1943	rc = iopt_pages_fill_xarray(pages, start_index, last_index, out_pages);
1944	if (rc)
1945		goto err_free;
1946
1947	access->node.start = start_index;
1948	access->node.last = last_index;
1949	access->users = 1;
1950	area->num_accesses++;
1951	interval_tree_insert(&access->node, &pages->access_itree);
1952	mutex_unlock(&pages->mutex);
1953	return 0;
1954
1955err_free:
1956	kfree(access);
1957err_unlock:
1958	mutex_unlock(&pages->mutex);
1959	return rc;
1960}
1961
1962/**
1963 * iopt_area_remove_access() - Release an in-kernel access for PFNs
1964 * @area: The source of PFNs
1965 * @start_index: First page index
1966 * @last_index: Inclusive last page index
1967 *
1968 * Undo iopt_area_add_access() and unpin the pages if necessary. The caller
1969 * must stop using the PFNs before calling this.
1970 */
1971void iopt_area_remove_access(struct iopt_area *area, unsigned long start_index,
1972			     unsigned long last_index)
1973{
1974	struct iopt_pages *pages = area->pages;
1975	struct iopt_pages_access *access;
1976
1977	mutex_lock(&pages->mutex);
1978	access = iopt_pages_get_exact_access(pages, start_index, last_index);
1979	if (WARN_ON(!access))
1980		goto out_unlock;
1981
1982	WARN_ON(area->num_accesses == 0 || access->users == 0);
1983	area->num_accesses--;
1984	access->users--;
1985	if (access->users)
1986		goto out_unlock;
1987
1988	interval_tree_remove(&access->node, &pages->access_itree);
1989	iopt_pages_unfill_xarray(pages, start_index, last_index);
1990	kfree(access);
1991out_unlock:
1992	mutex_unlock(&pages->mutex);
1993}
1994