1// SPDX-License-Identifier: GPL-2.0
2/* Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES.
3 *
4 * Kernel side components to support tools/testing/selftests/iommu
5 */
6#include <linux/slab.h>
7#include <linux/iommu.h>
8#include <linux/xarray.h>
9#include <linux/file.h>
10#include <linux/anon_inodes.h>
11#include <linux/fault-inject.h>
12#include <linux/platform_device.h>
13#include <uapi/linux/iommufd.h>
14
15#include "../iommu-priv.h"
16#include "io_pagetable.h"
17#include "iommufd_private.h"
18#include "iommufd_test.h"
19
/* Fault-injection attribute for the iommufd selftest */
static DECLARE_FAULT_ATTR(fail_iommufd);
/* debugfs directory for the selftest (populated outside this view) */
static struct dentry *dbgfs_root;
static struct platform_device *selftest_iommu_dev;
/* Forward declarations; both objects are defined later in this file */
static const struct iommu_ops mock_ops;
static struct iommu_domain_ops domain_nested_ops;

/* NOTE(review): a selftest memory limit; its users are outside this view */
size_t iommufd_test_memory_limit = 65536;
27
/* Bus that mock devices are placed on; mock_probe_device() only accepts it */
struct mock_bus_type {
	struct bus_type bus;
	struct notifier_block nb;
};

static struct mock_bus_type iommufd_mock_bus_type = {
	.bus = {
		.name = "iommufd_mock",
	},
};

/* Allocator for the numeric suffix of "iommufd_mock%u" device names */
static DEFINE_IDA(mock_dev_ida);
40
enum {
	/* mock_iommu_domain::flags bit: dirty tracking is enabled */
	MOCK_DIRTY_TRACK = 1,
	/* The mock IO page is half of a CPU page */
	MOCK_IO_PAGE_SIZE = PAGE_SIZE / 2,
	MOCK_HUGE_PAGE_SIZE = 512 * MOCK_IO_PAGE_SIZE,

	/*
	 * Like a real page table alignment requires the low bits of the address
	 * to be zero. xarray also requires the high bit to be zero, so we store
	 * the pfns shifted. The upper bits are used for metadata.
	 */
	MOCK_PFN_MASK = ULONG_MAX / MOCK_IO_PAGE_SIZE,

	/* Metadata bits stored above MOCK_PFN_MASK in each xarray value */
	_MOCK_PFN_START = MOCK_PFN_MASK + 1,
	/* START and LAST share one bit: both ends of a mapping are tagged alike */
	MOCK_PFN_START_IOVA = _MOCK_PFN_START,
	MOCK_PFN_LAST_IOVA = _MOCK_PFN_START,
	MOCK_PFN_DIRTY_IOVA = _MOCK_PFN_START << 1,
	MOCK_PFN_HUGE_IOVA = _MOCK_PFN_START << 2,
};
59
/*
 * Syzkaller has trouble randomizing the correct iova to use since it is linked
 * to the map ioctl's output, and it has no idea about that. So, simplify things.
 * In syzkaller mode the 64 bit IOVA is converted into an nth area and offset
 * value. This has a much smaller randomization space and syzkaller can hit it.
 */
66static unsigned long __iommufd_test_syz_conv_iova(struct io_pagetable *iopt,
67						  u64 *iova)
68{
69	struct syz_layout {
70		__u32 nth_area;
71		__u32 offset;
72	};
73	struct syz_layout *syz = (void *)iova;
74	unsigned int nth = syz->nth_area;
75	struct iopt_area *area;
76
77	down_read(&iopt->iova_rwsem);
78	for (area = iopt_area_iter_first(iopt, 0, ULONG_MAX); area;
79	     area = iopt_area_iter_next(area, 0, ULONG_MAX)) {
80		if (nth == 0) {
81			up_read(&iopt->iova_rwsem);
82			return iopt_area_iova(area) + syz->offset;
83		}
84		nth--;
85	}
86	up_read(&iopt->iova_rwsem);
87
88	return 0;
89}
90
91static unsigned long iommufd_test_syz_conv_iova(struct iommufd_access *access,
92						u64 *iova)
93{
94	unsigned long ret;
95
96	mutex_lock(&access->ioas_lock);
97	if (!access->ioas) {
98		mutex_unlock(&access->ioas_lock);
99		return 0;
100	}
101	ret = __iommufd_test_syz_conv_iova(&access->ioas->iopt, iova);
102	mutex_unlock(&access->ioas_lock);
103	return ret;
104}
105
106void iommufd_test_syz_conv_iova_id(struct iommufd_ucmd *ucmd,
107				   unsigned int ioas_id, u64 *iova, u32 *flags)
108{
109	struct iommufd_ioas *ioas;
110
111	if (!(*flags & MOCK_FLAGS_ACCESS_SYZ))
112		return;
113	*flags &= ~(u32)MOCK_FLAGS_ACCESS_SYZ;
114
115	ioas = iommufd_get_ioas(ucmd->ictx, ioas_id);
116	if (IS_ERR(ioas))
117		return;
118	*iova = __iommufd_test_syz_conv_iova(&ioas->iopt, iova);
119	iommufd_put_object(ucmd->ictx, &ioas->obj);
120}
121
/* A mock paging domain; its IOVA->PA translations live in @pfns */
struct mock_iommu_domain {
	/* MOCK_DIRTY_TRACK when dirty tracking is enabled */
	unsigned long flags;
	struct iommu_domain domain;
	/*
	 * Indexed by iova / MOCK_IO_PAGE_SIZE; each value is xa_mk_value() of
	 * paddr / MOCK_IO_PAGE_SIZE OR'd with MOCK_PFN_* metadata bits.
	 */
	struct xarray pfns;
};

/* A mock nested domain stacked on a mock paging parent */
struct mock_iommu_domain_nested {
	struct iommu_domain domain;
	struct mock_iommu_domain *parent;
	/* Mock IOTLB: filled at allocation, zeroed by cache invalidation */
	u32 iotlb[MOCK_NESTED_DOMAIN_IOTLB_NUM];
};

enum selftest_obj_type {
	TYPE_IDEV,
};

/* A device on the iommufd mock bus */
struct mock_dev {
	struct device dev;
	/* MOCK_FLAGS_DEVICE_* bits passed to mock_dev_create() */
	unsigned long flags;
	/* Name suffix allocated from mock_dev_ida */
	int id;
};

/* IOMMUFD_OBJ_SELFTEST object; @type selects the active union member */
struct selftest_obj {
	struct iommufd_object obj;
	enum selftest_obj_type type;

	union {
		struct {
			struct iommufd_device *idev;
			struct iommufd_ctx *ictx;
			struct mock_dev *mock_dev;
		} idev;
	};
};
156
157static int mock_domain_nop_attach(struct iommu_domain *domain,
158				  struct device *dev)
159{
160	struct mock_dev *mdev = container_of(dev, struct mock_dev, dev);
161
162	if (domain->dirty_ops && (mdev->flags & MOCK_FLAGS_DEVICE_NO_DIRTY))
163		return -EINVAL;
164
165	return 0;
166}
167
/* Ops for the static blocking domain below */
static const struct iommu_domain_ops mock_blocking_ops = {
	.attach_dev = mock_domain_nop_attach,
};

/*
 * Blocking domain used as both default_domain and blocked_domain in mock_ops.
 * It sets no dirty_ops, so mock_domain_nop_attach() always returns 0 for it.
 */
static struct iommu_domain mock_blocking_domain = {
	.type = IOMMU_DOMAIN_BLOCKED,
	.ops = &mock_blocking_ops,
};
176
177static void *mock_domain_hw_info(struct device *dev, u32 *length, u32 *type)
178{
179	struct iommu_test_hw_info *info;
180
181	info = kzalloc(sizeof(*info), GFP_KERNEL);
182	if (!info)
183		return ERR_PTR(-ENOMEM);
184
185	info->test_reg = IOMMU_HW_INFO_SELFTEST_REGVAL;
186	*length = sizeof(*info);
187	*type = IOMMU_HW_INFO_TYPE_SELFTEST;
188
189	return info;
190}
191
192static int mock_domain_set_dirty_tracking(struct iommu_domain *domain,
193					  bool enable)
194{
195	struct mock_iommu_domain *mock =
196		container_of(domain, struct mock_iommu_domain, domain);
197	unsigned long flags = mock->flags;
198
199	if (enable && !domain->dirty_ops)
200		return -EINVAL;
201
202	/* No change? */
203	if (!(enable ^ !!(flags & MOCK_DIRTY_TRACK)))
204		return 0;
205
206	flags = (enable ? flags | MOCK_DIRTY_TRACK : flags & ~MOCK_DIRTY_TRACK);
207
208	mock->flags = flags;
209	return 0;
210}
211
212static bool mock_test_and_clear_dirty(struct mock_iommu_domain *mock,
213				      unsigned long iova, size_t page_size,
214				      unsigned long flags)
215{
216	unsigned long cur, end = iova + page_size - 1;
217	bool dirty = false;
218	void *ent, *old;
219
220	for (cur = iova; cur < end; cur += MOCK_IO_PAGE_SIZE) {
221		ent = xa_load(&mock->pfns, cur / MOCK_IO_PAGE_SIZE);
222		if (!ent || !(xa_to_value(ent) & MOCK_PFN_DIRTY_IOVA))
223			continue;
224
225		dirty = true;
226		/* Clear dirty */
227		if (!(flags & IOMMU_DIRTY_NO_CLEAR)) {
228			unsigned long val;
229
230			val = xa_to_value(ent) & ~MOCK_PFN_DIRTY_IOVA;
231			old = xa_store(&mock->pfns, cur / MOCK_IO_PAGE_SIZE,
232				       xa_mk_value(val), GFP_KERNEL);
233			WARN_ON_ONCE(ent != old);
234		}
235	}
236
237	return dirty;
238}
239
/*
 * read_and_clear_dirty op: walk [iova, iova + size) one mapping at a time and
 * record each IO page (or whole huge page) whose entries carry
 * MOCK_PFN_DIRTY_IOVA into @dirty via iommu_dirty_bitmap_record().
 * Dirty bits are cleared unless @flags has IOMMU_DIRTY_NO_CLEAR.
 *
 * Returns -EINVAL when tracking is disabled and a bitmap was supplied.
 */
static int mock_domain_read_and_clear_dirty(struct iommu_domain *domain,
					    unsigned long iova, size_t size,
					    unsigned long flags,
					    struct iommu_dirty_bitmap *dirty)
{
	struct mock_iommu_domain *mock =
		container_of(domain, struct mock_iommu_domain, domain);
	unsigned long end = iova + size;
	void *ent;

	/* With tracking off, only a request without a bitmap is accepted */
	if (!(mock->flags & MOCK_DIRTY_TRACK) && dirty->bitmap)
		return -EINVAL;

	/* do/while: at least one entry is examined even when size is 0 */
	do {
		unsigned long pgsize = MOCK_IO_PAGE_SIZE;
		unsigned long head;

		ent = xa_load(&mock->pfns, iova / MOCK_IO_PAGE_SIZE);
		if (!ent) {
			/* Unmapped IO page: step over it (continue re-tests iova < end) */
			iova += pgsize;
			continue;
		}

		/* Huge mappings are tested and recorded as a single unit */
		if (xa_to_value(ent) & MOCK_PFN_HUGE_IOVA)
			pgsize = MOCK_HUGE_PAGE_SIZE;
		head = iova & ~(pgsize - 1);

		/* Clear dirty */
		if (mock_test_and_clear_dirty(mock, head, pgsize, flags))
			iommu_dirty_bitmap_record(dirty, head, pgsize);
		iova = head + pgsize;
	} while (iova < end);

	return 0;
}
275
276const struct iommu_dirty_ops dirty_ops = {
277	.set_dirty_tracking = mock_domain_set_dirty_tracking,
278	.read_and_clear_dirty = mock_domain_read_and_clear_dirty,
279};
280
281static struct iommu_domain *mock_domain_alloc_paging(struct device *dev)
282{
283	struct mock_dev *mdev = container_of(dev, struct mock_dev, dev);
284	struct mock_iommu_domain *mock;
285
286	mock = kzalloc(sizeof(*mock), GFP_KERNEL);
287	if (!mock)
288		return NULL;
289	mock->domain.geometry.aperture_start = MOCK_APERTURE_START;
290	mock->domain.geometry.aperture_end = MOCK_APERTURE_LAST;
291	mock->domain.pgsize_bitmap = MOCK_IO_PAGE_SIZE;
292	if (dev && mdev->flags & MOCK_FLAGS_DEVICE_HUGE_IOVA)
293		mock->domain.pgsize_bitmap |= MOCK_HUGE_PAGE_SIZE;
294	mock->domain.ops = mock_ops.default_domain_ops;
295	mock->domain.type = IOMMU_DOMAIN_UNMANAGED;
296	xa_init(&mock->pfns);
297	return &mock->domain;
298}
299
300static struct iommu_domain *
301__mock_domain_alloc_nested(struct mock_iommu_domain *mock_parent,
302			   const struct iommu_hwpt_selftest *user_cfg)
303{
304	struct mock_iommu_domain_nested *mock_nested;
305	int i;
306
307	mock_nested = kzalloc(sizeof(*mock_nested), GFP_KERNEL);
308	if (!mock_nested)
309		return ERR_PTR(-ENOMEM);
310	mock_nested->parent = mock_parent;
311	mock_nested->domain.ops = &domain_nested_ops;
312	mock_nested->domain.type = IOMMU_DOMAIN_NESTED;
313	for (i = 0; i < MOCK_NESTED_DOMAIN_IOTLB_NUM; i++)
314		mock_nested->iotlb[i] = user_cfg->iotlb;
315	return &mock_nested->domain;
316}
317
318static struct iommu_domain *
319mock_domain_alloc_user(struct device *dev, u32 flags,
320		       struct iommu_domain *parent,
321		       const struct iommu_user_data *user_data)
322{
323	struct mock_iommu_domain *mock_parent;
324	struct iommu_hwpt_selftest user_cfg;
325	int rc;
326
327	/* must be mock_domain */
328	if (!parent) {
329		struct mock_dev *mdev = container_of(dev, struct mock_dev, dev);
330		bool has_dirty_flag = flags & IOMMU_HWPT_ALLOC_DIRTY_TRACKING;
331		bool no_dirty_ops = mdev->flags & MOCK_FLAGS_DEVICE_NO_DIRTY;
332		struct iommu_domain *domain;
333
334		if (flags & (~(IOMMU_HWPT_ALLOC_NEST_PARENT |
335			       IOMMU_HWPT_ALLOC_DIRTY_TRACKING)))
336			return ERR_PTR(-EOPNOTSUPP);
337		if (user_data || (has_dirty_flag && no_dirty_ops))
338			return ERR_PTR(-EOPNOTSUPP);
339		domain = mock_domain_alloc_paging(dev);
340		if (!domain)
341			return ERR_PTR(-ENOMEM);
342		if (has_dirty_flag)
343			container_of(domain, struct mock_iommu_domain, domain)
344				->domain.dirty_ops = &dirty_ops;
345		return domain;
346	}
347
348	/* must be mock_domain_nested */
349	if (user_data->type != IOMMU_HWPT_DATA_SELFTEST || flags)
350		return ERR_PTR(-EOPNOTSUPP);
351	if (!parent || parent->ops != mock_ops.default_domain_ops)
352		return ERR_PTR(-EINVAL);
353
354	mock_parent = container_of(parent, struct mock_iommu_domain, domain);
355	if (!mock_parent)
356		return ERR_PTR(-EINVAL);
357
358	rc = iommu_copy_struct_from_user(&user_cfg, user_data,
359					 IOMMU_HWPT_DATA_SELFTEST, iotlb);
360	if (rc)
361		return ERR_PTR(rc);
362
363	return __mock_domain_alloc_nested(mock_parent, &user_cfg);
364}
365
366static void mock_domain_free(struct iommu_domain *domain)
367{
368	struct mock_iommu_domain *mock =
369		container_of(domain, struct mock_iommu_domain, domain);
370
371	WARN_ON(!xa_empty(&mock->pfns));
372	kfree(mock);
373}
374
/*
 * map_pages op: store one xarray entry per MOCK_IO_PAGE_SIZE chunk covering
 * pgcount * pgsize bytes starting at @iova. The first entry is tagged
 * MOCK_PFN_START_IOVA, the final one MOCK_PFN_LAST_IOVA, and every entry of a
 * mapping whose pgsize is not MOCK_IO_PAGE_SIZE gets MOCK_PFN_HUGE_IOVA.
 * *@mapped grows by MOCK_IO_PAGE_SIZE per stored entry. On an xarray error
 * the entries stored by this call are erased (*@mapped is not rolled back).
 */
static int mock_domain_map_pages(struct iommu_domain *domain,
				 unsigned long iova, phys_addr_t paddr,
				 size_t pgsize, size_t pgcount, int prot,
				 gfp_t gfp, size_t *mapped)
{
	struct mock_iommu_domain *mock =
		container_of(domain, struct mock_iommu_domain, domain);
	/* Metadata for the next entry; only the very first gets START */
	unsigned long flags = MOCK_PFN_START_IOVA;
	/* Remembered so a failure can erase everything stored so far */
	unsigned long start_iova = iova;

	/*
	 * xarray does not reliably work with fault injection because it does a
	 * retry allocation, so put our own failure point.
	 */
	if (iommufd_should_fail())
		return -ENOENT;

	WARN_ON(iova % MOCK_IO_PAGE_SIZE);
	WARN_ON(pgsize % MOCK_IO_PAGE_SIZE);
	for (; pgcount; pgcount--) {
		size_t cur;

		/* Each page is split into MOCK_IO_PAGE_SIZE xarray entries */
		for (cur = 0; cur != pgsize; cur += MOCK_IO_PAGE_SIZE) {
			void *old;

			/*
			 * Last chunk of the last page. START and LAST are the
			 * same bit, so this never drops a START tag.
			 */
			if (pgcount == 1 && cur + MOCK_IO_PAGE_SIZE == pgsize)
				flags = MOCK_PFN_LAST_IOVA;
			if (pgsize != MOCK_IO_PAGE_SIZE) {
				flags |= MOCK_PFN_HUGE_IOVA;
			}
			old = xa_store(&mock->pfns, iova / MOCK_IO_PAGE_SIZE,
				       xa_mk_value((paddr / MOCK_IO_PAGE_SIZE) |
						   flags),
				       gfp);
			if (xa_is_err(old)) {
				/* Undo every entry this call stored */
				for (; start_iova != iova;
				     start_iova += MOCK_IO_PAGE_SIZE)
					xa_erase(&mock->pfns,
						 start_iova /
							 MOCK_IO_PAGE_SIZE);
				return xa_err(old);
			}
			/* The IOVA is not expected to be mapped already */
			WARN_ON(old);
			iova += MOCK_IO_PAGE_SIZE;
			paddr += MOCK_IO_PAGE_SIZE;
			*mapped += MOCK_IO_PAGE_SIZE;
			flags = 0;
		}
	}
	return 0;
}
426
/*
 * unmap_pages op: erase one xarray entry per MOCK_IO_PAGE_SIZE chunk covering
 * pgcount * pgsize bytes at @iova. Returns the number of bytes walked
 * (pgcount * pgsize), whether or not the entries were present.
 */
static size_t mock_domain_unmap_pages(struct iommu_domain *domain,
				      unsigned long iova, size_t pgsize,
				      size_t pgcount,
				      struct iommu_iotlb_gather *iotlb_gather)
{
	struct mock_iommu_domain *mock =
		container_of(domain, struct mock_iommu_domain, domain);
	/* True until the first entry has been checked for the START tag */
	bool first = true;
	size_t ret = 0;
	void *ent;

	WARN_ON(iova % MOCK_IO_PAGE_SIZE);
	WARN_ON(pgsize % MOCK_IO_PAGE_SIZE);

	for (; pgcount; pgcount--) {
		size_t cur;

		for (cur = 0; cur != pgsize; cur += MOCK_IO_PAGE_SIZE) {
			ent = xa_erase(&mock->pfns, iova / MOCK_IO_PAGE_SIZE);

			/*
			 * iommufd generates unmaps that must be a strict
			 * superset of the maps performed. So every
			 * starting/ending IOVA should have been an iova passed
			 * to map.
			 *
			 * This simple logic doesn't work when the HUGE_PAGE is
			 * turned on since the core code will automatically
			 * switch between the two page sizes creating a break in
			 * the unmap calls. The break can land in the middle of
			 * contiguous IOVA.
			 */
			if (!(domain->pgsize_bitmap & MOCK_HUGE_PAGE_SIZE)) {
				if (first) {
					WARN_ON(ent && !(xa_to_value(ent) &
							 MOCK_PFN_START_IOVA));
					first = false;
				}
				if (pgcount == 1 &&
				    cur + MOCK_IO_PAGE_SIZE == pgsize)
					WARN_ON(ent && !(xa_to_value(ent) &
							 MOCK_PFN_LAST_IOVA));
			}

			iova += MOCK_IO_PAGE_SIZE;
			ret += MOCK_IO_PAGE_SIZE;
		}
	}
	return ret;
}
477
478static phys_addr_t mock_domain_iova_to_phys(struct iommu_domain *domain,
479					    dma_addr_t iova)
480{
481	struct mock_iommu_domain *mock =
482		container_of(domain, struct mock_iommu_domain, domain);
483	void *ent;
484
485	WARN_ON(iova % MOCK_IO_PAGE_SIZE);
486	ent = xa_load(&mock->pfns, iova / MOCK_IO_PAGE_SIZE);
487	WARN_ON(!ent);
488	return (xa_to_value(ent) & MOCK_PFN_MASK) * MOCK_IO_PAGE_SIZE;
489}
490
491static bool mock_domain_capable(struct device *dev, enum iommu_cap cap)
492{
493	struct mock_dev *mdev = container_of(dev, struct mock_dev, dev);
494
495	switch (cap) {
496	case IOMMU_CAP_CACHE_COHERENCY:
497		return true;
498	case IOMMU_CAP_DIRTY_TRACKING:
499		return !(mdev->flags & MOCK_FLAGS_DEVICE_NO_DIRTY);
500	default:
501		break;
502	}
503
504	return false;
505}
506
507static struct iommu_device mock_iommu_device = {
508};
509
510static struct iommu_device *mock_probe_device(struct device *dev)
511{
512	if (dev->bus != &iommufd_mock_bus_type.bus)
513		return ERR_PTR(-ENODEV);
514	return &mock_iommu_device;
515}
516
/* iommu_ops of the mock driver; devices are accepted by mock_probe_device() */
static const struct iommu_ops mock_ops = {
	/*
	 * IOMMU_DOMAIN_BLOCKED cannot be returned from def_domain_type()
	 * because it is zero.
	 */
	.default_domain = &mock_blocking_domain,
	.blocked_domain = &mock_blocking_domain,
	.owner = THIS_MODULE,
	.pgsize_bitmap = MOCK_IO_PAGE_SIZE,
	.hw_info = mock_domain_hw_info,
	.domain_alloc_paging = mock_domain_alloc_paging,
	.domain_alloc_user = mock_domain_alloc_user,
	.capable = mock_domain_capable,
	.device_group = generic_device_group,
	.probe_device = mock_probe_device,
	/* Installed on paging domains by mock_domain_alloc_paging() */
	.default_domain_ops =
		&(struct iommu_domain_ops){
			.free = mock_domain_free,
			.attach_dev = mock_domain_nop_attach,
			.map_pages = mock_domain_map_pages,
			.unmap_pages = mock_domain_unmap_pages,
			.iova_to_phys = mock_domain_iova_to_phys,
		},
};
541
542static void mock_domain_free_nested(struct iommu_domain *domain)
543{
544	struct mock_iommu_domain_nested *mock_nested =
545		container_of(domain, struct mock_iommu_domain_nested, domain);
546
547	kfree(mock_nested);
548}
549
550static int
551mock_domain_cache_invalidate_user(struct iommu_domain *domain,
552				  struct iommu_user_data_array *array)
553{
554	struct mock_iommu_domain_nested *mock_nested =
555		container_of(domain, struct mock_iommu_domain_nested, domain);
556	struct iommu_hwpt_invalidate_selftest inv;
557	u32 processed = 0;
558	int i = 0, j;
559	int rc = 0;
560
561	if (array->type != IOMMU_HWPT_INVALIDATE_DATA_SELFTEST) {
562		rc = -EINVAL;
563		goto out;
564	}
565
566	for ( ; i < array->entry_num; i++) {
567		rc = iommu_copy_struct_from_user_array(&inv, array,
568						       IOMMU_HWPT_INVALIDATE_DATA_SELFTEST,
569						       i, iotlb_id);
570		if (rc)
571			break;
572
573		if (inv.flags & ~IOMMU_TEST_INVALIDATE_FLAG_ALL) {
574			rc = -EOPNOTSUPP;
575			break;
576		}
577
578		if (inv.iotlb_id > MOCK_NESTED_DOMAIN_IOTLB_ID_MAX) {
579			rc = -EINVAL;
580			break;
581		}
582
583		if (inv.flags & IOMMU_TEST_INVALIDATE_FLAG_ALL) {
584			/* Invalidate all mock iotlb entries and ignore iotlb_id */
585			for (j = 0; j < MOCK_NESTED_DOMAIN_IOTLB_NUM; j++)
586				mock_nested->iotlb[j] = 0;
587		} else {
588			mock_nested->iotlb[inv.iotlb_id] = 0;
589		}
590
591		processed++;
592	}
593
594out:
595	array->entry_num = processed;
596	return rc;
597}
598
/* Ops for nested domains built by __mock_domain_alloc_nested() */
static struct iommu_domain_ops domain_nested_ops = {
	.free = mock_domain_free_nested,
	.attach_dev = mock_domain_nop_attach,
	.cache_invalidate_user = mock_domain_cache_invalidate_user,
};
604
605static inline struct iommufd_hw_pagetable *
606__get_md_pagetable(struct iommufd_ucmd *ucmd, u32 mockpt_id, u32 hwpt_type)
607{
608	struct iommufd_object *obj;
609
610	obj = iommufd_get_object(ucmd->ictx, mockpt_id, hwpt_type);
611	if (IS_ERR(obj))
612		return ERR_CAST(obj);
613	return container_of(obj, struct iommufd_hw_pagetable, obj);
614}
615
616static inline struct iommufd_hw_pagetable *
617get_md_pagetable(struct iommufd_ucmd *ucmd, u32 mockpt_id,
618		 struct mock_iommu_domain **mock)
619{
620	struct iommufd_hw_pagetable *hwpt;
621
622	hwpt = __get_md_pagetable(ucmd, mockpt_id, IOMMUFD_OBJ_HWPT_PAGING);
623	if (IS_ERR(hwpt))
624		return hwpt;
625	if (hwpt->domain->type != IOMMU_DOMAIN_UNMANAGED ||
626	    hwpt->domain->ops != mock_ops.default_domain_ops) {
627		iommufd_put_object(ucmd->ictx, &hwpt->obj);
628		return ERR_PTR(-EINVAL);
629	}
630	*mock = container_of(hwpt->domain, struct mock_iommu_domain, domain);
631	return hwpt;
632}
633
634static inline struct iommufd_hw_pagetable *
635get_md_pagetable_nested(struct iommufd_ucmd *ucmd, u32 mockpt_id,
636			struct mock_iommu_domain_nested **mock_nested)
637{
638	struct iommufd_hw_pagetable *hwpt;
639
640	hwpt = __get_md_pagetable(ucmd, mockpt_id, IOMMUFD_OBJ_HWPT_NESTED);
641	if (IS_ERR(hwpt))
642		return hwpt;
643	if (hwpt->domain->type != IOMMU_DOMAIN_NESTED ||
644	    hwpt->domain->ops != &domain_nested_ops) {
645		iommufd_put_object(ucmd->ictx, &hwpt->obj);
646		return ERR_PTR(-EINVAL);
647	}
648	*mock_nested = container_of(hwpt->domain,
649				    struct mock_iommu_domain_nested, domain);
650	return hwpt;
651}
652
653static void mock_dev_release(struct device *dev)
654{
655	struct mock_dev *mdev = container_of(dev, struct mock_dev, dev);
656
657	ida_free(&mock_dev_ida, mdev->id);
658	kfree(mdev);
659}
660
661static struct mock_dev *mock_dev_create(unsigned long dev_flags)
662{
663	struct mock_dev *mdev;
664	int rc;
665
666	if (dev_flags &
667	    ~(MOCK_FLAGS_DEVICE_NO_DIRTY | MOCK_FLAGS_DEVICE_HUGE_IOVA))
668		return ERR_PTR(-EINVAL);
669
670	mdev = kzalloc(sizeof(*mdev), GFP_KERNEL);
671	if (!mdev)
672		return ERR_PTR(-ENOMEM);
673
674	device_initialize(&mdev->dev);
675	mdev->flags = dev_flags;
676	mdev->dev.release = mock_dev_release;
677	mdev->dev.bus = &iommufd_mock_bus_type.bus;
678
679	rc = ida_alloc(&mock_dev_ida, GFP_KERNEL);
680	if (rc < 0)
681		goto err_put;
682	mdev->id = rc;
683
684	rc = dev_set_name(&mdev->dev, "iommufd_mock%u", mdev->id);
685	if (rc)
686		goto err_put;
687
688	rc = device_add(&mdev->dev);
689	if (rc)
690		goto err_put;
691	return mdev;
692
693err_put:
694	put_device(&mdev->dev);
695	return ERR_PTR(rc);
696}
697
698static void mock_dev_destroy(struct mock_dev *mdev)
699{
700	device_unregister(&mdev->dev);
701}
702
703bool iommufd_selftest_is_mock_dev(struct device *dev)
704{
705	return dev->release == mock_dev_release;
706}
707
/*
 * Create an hw_pagetable with the mock domain so we can test the domain ops.
 *
 * Allocates an IOMMUFD_OBJ_SELFTEST object and creates a mock device (using
 * cmd->mock_domain_flags.dev_flags for IOMMU_TEST_OP_MOCK_DOMAIN_FLAGS).
 * The device is bound to ucmd->ictx and attached to cmd->id. The hwpt,
 * selftest object and idev ids are then reported back. On failure each
 * completed step is undone in reverse order.
 */
static int iommufd_test_mock_domain(struct iommufd_ucmd *ucmd,
				    struct iommu_test_cmd *cmd)
{
	struct iommufd_device *idev;
	struct selftest_obj *sobj;
	/* Passed by reference to the attach; the result is out_hwpt_id */
	u32 pt_id = cmd->id;
	u32 dev_flags = 0;
	u32 idev_id;
	int rc;

	sobj = iommufd_object_alloc(ucmd->ictx, sobj, IOMMUFD_OBJ_SELFTEST);
	if (IS_ERR(sobj))
		return PTR_ERR(sobj);

	sobj->idev.ictx = ucmd->ictx;
	sobj->type = TYPE_IDEV;

	/* Only the _FLAGS variant of the op carries device flags */
	if (cmd->op == IOMMU_TEST_OP_MOCK_DOMAIN_FLAGS)
		dev_flags = cmd->mock_domain_flags.dev_flags;

	sobj->idev.mock_dev = mock_dev_create(dev_flags);
	if (IS_ERR(sobj->idev.mock_dev)) {
		rc = PTR_ERR(sobj->idev.mock_dev);
		goto out_sobj;
	}

	idev = iommufd_device_bind(ucmd->ictx, &sobj->idev.mock_dev->dev,
				   &idev_id);
	if (IS_ERR(idev)) {
		rc = PTR_ERR(idev);
		goto out_mdev;
	}
	sobj->idev.idev = idev;

	rc = iommufd_device_attach(idev, &pt_id);
	if (rc)
		goto out_unbind;

	/* Userspace must destroy the device_id to destroy the object */
	cmd->mock_domain.out_hwpt_id = pt_id;
	cmd->mock_domain.out_stdev_id = sobj->obj.id;
	cmd->mock_domain.out_idev_id = idev_id;
	rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
	if (rc)
		goto out_detach;
	iommufd_object_finalize(ucmd->ictx, &sobj->obj);
	return 0;

	/* Error unwind, in reverse order of setup */
out_detach:
	iommufd_device_detach(idev);
out_unbind:
	iommufd_device_unbind(idev);
out_mdev:
	mock_dev_destroy(sobj->idev.mock_dev);
out_sobj:
	iommufd_object_abort(ucmd->ictx, &sobj->obj);
	return rc;
}
767
768/* Replace the mock domain with a manually allocated hw_pagetable */
769static int iommufd_test_mock_domain_replace(struct iommufd_ucmd *ucmd,
770					    unsigned int device_id, u32 pt_id,
771					    struct iommu_test_cmd *cmd)
772{
773	struct iommufd_object *dev_obj;
774	struct selftest_obj *sobj;
775	int rc;
776
777	/*
778	 * Prefer to use the OBJ_SELFTEST because the destroy_rwsem will ensure
779	 * it doesn't race with detach, which is not allowed.
780	 */
781	dev_obj =
782		iommufd_get_object(ucmd->ictx, device_id, IOMMUFD_OBJ_SELFTEST);
783	if (IS_ERR(dev_obj))
784		return PTR_ERR(dev_obj);
785
786	sobj = container_of(dev_obj, struct selftest_obj, obj);
787	if (sobj->type != TYPE_IDEV) {
788		rc = -EINVAL;
789		goto out_dev_obj;
790	}
791
792	rc = iommufd_device_replace(sobj->idev.idev, &pt_id);
793	if (rc)
794		goto out_dev_obj;
795
796	cmd->mock_domain_replace.pt_id = pt_id;
797	rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
798
799out_dev_obj:
800	iommufd_put_object(ucmd->ictx, dev_obj);
801	return rc;
802}
803
804/* Add an additional reserved IOVA to the IOAS */
805static int iommufd_test_add_reserved(struct iommufd_ucmd *ucmd,
806				     unsigned int mockpt_id,
807				     unsigned long start, size_t length)
808{
809	struct iommufd_ioas *ioas;
810	int rc;
811
812	ioas = iommufd_get_ioas(ucmd->ictx, mockpt_id);
813	if (IS_ERR(ioas))
814		return PTR_ERR(ioas);
815	down_write(&ioas->iopt.iova_rwsem);
816	rc = iopt_reserve_iova(&ioas->iopt, start, start + length - 1, NULL);
817	up_write(&ioas->iopt.iova_rwsem);
818	iommufd_put_object(ucmd->ictx, &ioas->obj);
819	return rc;
820}
821
822/* Check that every pfn under each iova matches the pfn under a user VA */
823static int iommufd_test_md_check_pa(struct iommufd_ucmd *ucmd,
824				    unsigned int mockpt_id, unsigned long iova,
825				    size_t length, void __user *uptr)
826{
827	struct iommufd_hw_pagetable *hwpt;
828	struct mock_iommu_domain *mock;
829	uintptr_t end;
830	int rc;
831
832	if (iova % MOCK_IO_PAGE_SIZE || length % MOCK_IO_PAGE_SIZE ||
833	    (uintptr_t)uptr % MOCK_IO_PAGE_SIZE ||
834	    check_add_overflow((uintptr_t)uptr, (uintptr_t)length, &end))
835		return -EINVAL;
836
837	hwpt = get_md_pagetable(ucmd, mockpt_id, &mock);
838	if (IS_ERR(hwpt))
839		return PTR_ERR(hwpt);
840
841	for (; length; length -= MOCK_IO_PAGE_SIZE) {
842		struct page *pages[1];
843		unsigned long pfn;
844		long npages;
845		void *ent;
846
847		npages = get_user_pages_fast((uintptr_t)uptr & PAGE_MASK, 1, 0,
848					     pages);
849		if (npages < 0) {
850			rc = npages;
851			goto out_put;
852		}
853		if (WARN_ON(npages != 1)) {
854			rc = -EFAULT;
855			goto out_put;
856		}
857		pfn = page_to_pfn(pages[0]);
858		put_page(pages[0]);
859
860		ent = xa_load(&mock->pfns, iova / MOCK_IO_PAGE_SIZE);
861		if (!ent ||
862		    (xa_to_value(ent) & MOCK_PFN_MASK) * MOCK_IO_PAGE_SIZE !=
863			    pfn * PAGE_SIZE + ((uintptr_t)uptr % PAGE_SIZE)) {
864			rc = -EINVAL;
865			goto out_put;
866		}
867		iova += MOCK_IO_PAGE_SIZE;
868		uptr += MOCK_IO_PAGE_SIZE;
869	}
870	rc = 0;
871
872out_put:
873	iommufd_put_object(ucmd->ictx, &hwpt->obj);
874	return rc;
875}
876
877/* Check that the page ref count matches, to look for missing pin/unpins */
878static int iommufd_test_md_check_refs(struct iommufd_ucmd *ucmd,
879				      void __user *uptr, size_t length,
880				      unsigned int refs)
881{
882	uintptr_t end;
883
884	if (length % PAGE_SIZE || (uintptr_t)uptr % PAGE_SIZE ||
885	    check_add_overflow((uintptr_t)uptr, (uintptr_t)length, &end))
886		return -EINVAL;
887
888	for (; length; length -= PAGE_SIZE) {
889		struct page *pages[1];
890		long npages;
891
892		npages = get_user_pages_fast((uintptr_t)uptr, 1, 0, pages);
893		if (npages < 0)
894			return npages;
895		if (WARN_ON(npages != 1))
896			return -EFAULT;
897		if (!PageCompound(pages[0])) {
898			unsigned int count;
899
900			count = page_ref_count(pages[0]);
901			if (count / GUP_PIN_COUNTING_BIAS != refs) {
902				put_page(pages[0]);
903				return -EIO;
904			}
905		}
906		put_page(pages[0]);
907		uptr += PAGE_SIZE;
908	}
909	return 0;
910}
911
912static int iommufd_test_md_check_iotlb(struct iommufd_ucmd *ucmd,
913				       u32 mockpt_id, unsigned int iotlb_id,
914				       u32 iotlb)
915{
916	struct mock_iommu_domain_nested *mock_nested;
917	struct iommufd_hw_pagetable *hwpt;
918	int rc = 0;
919
920	hwpt = get_md_pagetable_nested(ucmd, mockpt_id, &mock_nested);
921	if (IS_ERR(hwpt))
922		return PTR_ERR(hwpt);
923
924	mock_nested = container_of(hwpt->domain,
925				   struct mock_iommu_domain_nested, domain);
926
927	if (iotlb_id > MOCK_NESTED_DOMAIN_IOTLB_ID_MAX ||
928	    mock_nested->iotlb[iotlb_id] != iotlb)
929		rc = -EINVAL;
930	iommufd_put_object(ucmd->ictx, &hwpt->obj);
931	return rc;
932}
933
/* State behind an "[iommufd_test_staccess]" file descriptor */
struct selftest_access {
	/* NULL until iommufd_test_create_access() has fully succeeded */
	struct iommufd_access *access;
	struct file *file;
	/* Held around every visible walk or modification of @items */
	struct mutex lock;
	/* selftest_access_item entries, linked through items_elm */
	struct list_head items;
	unsigned int next_id;
	bool destroying;
};

/* One pinned IOVA range tracked on selftest_access::items */
struct selftest_access_item {
	struct list_head items_elm;
	unsigned long iova;
	size_t length;
	unsigned int id;
};

/* Forward declaration; iommufd_access_get() uses it to identify our files */
static const struct file_operations iommfd_test_staccess_fops;
951
952static struct selftest_access *iommufd_access_get(int fd)
953{
954	struct file *file;
955
956	file = fget(fd);
957	if (!file)
958		return ERR_PTR(-EBADFD);
959
960	if (file->f_op != &iommfd_test_staccess_fops) {
961		fput(file);
962		return ERR_PTR(-EBADFD);
963	}
964	return file->private_data;
965}
966
967static void iommufd_test_access_unmap(void *data, unsigned long iova,
968				      unsigned long length)
969{
970	unsigned long iova_last = iova + length - 1;
971	struct selftest_access *staccess = data;
972	struct selftest_access_item *item;
973	struct selftest_access_item *tmp;
974
975	mutex_lock(&staccess->lock);
976	list_for_each_entry_safe(item, tmp, &staccess->items, items_elm) {
977		if (iova > item->iova + item->length - 1 ||
978		    iova_last < item->iova)
979			continue;
980		list_del(&item->items_elm);
981		iommufd_access_unpin_pages(staccess->access, item->iova,
982					   item->length);
983		kfree(item);
984	}
985	mutex_unlock(&staccess->lock);
986}
987
988static int iommufd_test_access_item_destroy(struct iommufd_ucmd *ucmd,
989					    unsigned int access_id,
990					    unsigned int item_id)
991{
992	struct selftest_access_item *item;
993	struct selftest_access *staccess;
994
995	staccess = iommufd_access_get(access_id);
996	if (IS_ERR(staccess))
997		return PTR_ERR(staccess);
998
999	mutex_lock(&staccess->lock);
1000	list_for_each_entry(item, &staccess->items, items_elm) {
1001		if (item->id == item_id) {
1002			list_del(&item->items_elm);
1003			iommufd_access_unpin_pages(staccess->access, item->iova,
1004						   item->length);
1005			mutex_unlock(&staccess->lock);
1006			kfree(item);
1007			fput(staccess->file);
1008			return 0;
1009		}
1010	}
1011	mutex_unlock(&staccess->lock);
1012	fput(staccess->file);
1013	return -ENOENT;
1014}
1015
1016static int iommufd_test_staccess_release(struct inode *inode,
1017					 struct file *filep)
1018{
1019	struct selftest_access *staccess = filep->private_data;
1020
1021	if (staccess->access) {
1022		iommufd_test_access_unmap(staccess, 0, ULONG_MAX);
1023		iommufd_access_destroy(staccess->access);
1024	}
1025	mutex_destroy(&staccess->lock);
1026	kfree(staccess);
1027	return 0;
1028}
1029
/* Access ops selected with MOCK_FLAGS_ACCESS_CREATE_NEEDS_PIN_PAGES */
static const struct iommufd_access_ops selftest_access_ops_pin = {
	.needs_pin_pages = 1,
	.unmap = iommufd_test_access_unmap,
};

/* Access ops selected when pinning is not requested */
static const struct iommufd_access_ops selftest_access_ops = {
	.unmap = iommufd_test_access_unmap,
};

/* Only .release is provided; it tears down the selftest_access */
static const struct file_operations iommfd_test_staccess_fops = {
	.release = iommufd_test_staccess_release,
};
1042
1043static struct selftest_access *iommufd_test_alloc_access(void)
1044{
1045	struct selftest_access *staccess;
1046	struct file *filep;
1047
1048	staccess = kzalloc(sizeof(*staccess), GFP_KERNEL_ACCOUNT);
1049	if (!staccess)
1050		return ERR_PTR(-ENOMEM);
1051	INIT_LIST_HEAD(&staccess->items);
1052	mutex_init(&staccess->lock);
1053
1054	filep = anon_inode_getfile("[iommufd_test_staccess]",
1055				   &iommfd_test_staccess_fops, staccess,
1056				   O_RDWR);
1057	if (IS_ERR(filep)) {
1058		kfree(staccess);
1059		return ERR_CAST(filep);
1060	}
1061	staccess->file = filep;
1062	return staccess;
1063}
1064
1065static int iommufd_test_create_access(struct iommufd_ucmd *ucmd,
1066				      unsigned int ioas_id, unsigned int flags)
1067{
1068	struct iommu_test_cmd *cmd = ucmd->cmd;
1069	struct selftest_access *staccess;
1070	struct iommufd_access *access;
1071	u32 id;
1072	int fdno;
1073	int rc;
1074
1075	if (flags & ~MOCK_FLAGS_ACCESS_CREATE_NEEDS_PIN_PAGES)
1076		return -EOPNOTSUPP;
1077
1078	staccess = iommufd_test_alloc_access();
1079	if (IS_ERR(staccess))
1080		return PTR_ERR(staccess);
1081
1082	fdno = get_unused_fd_flags(O_CLOEXEC);
1083	if (fdno < 0) {
1084		rc = -ENOMEM;
1085		goto out_free_staccess;
1086	}
1087
1088	access = iommufd_access_create(
1089		ucmd->ictx,
1090		(flags & MOCK_FLAGS_ACCESS_CREATE_NEEDS_PIN_PAGES) ?
1091			&selftest_access_ops_pin :
1092			&selftest_access_ops,
1093		staccess, &id);
1094	if (IS_ERR(access)) {
1095		rc = PTR_ERR(access);
1096		goto out_put_fdno;
1097	}
1098	rc = iommufd_access_attach(access, ioas_id);
1099	if (rc)
1100		goto out_destroy;
1101	cmd->create_access.out_access_fd = fdno;
1102	rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
1103	if (rc)
1104		goto out_destroy;
1105
1106	staccess->access = access;
1107	fd_install(fdno, staccess->file);
1108	return 0;
1109
1110out_destroy:
1111	iommufd_access_destroy(access);
1112out_put_fdno:
1113	put_unused_fd(fdno);
1114out_free_staccess:
1115	fput(staccess->file);
1116	return rc;
1117}
1118
1119static int iommufd_test_access_replace_ioas(struct iommufd_ucmd *ucmd,
1120					    unsigned int access_id,
1121					    unsigned int ioas_id)
1122{
1123	struct selftest_access *staccess;
1124	int rc;
1125
1126	staccess = iommufd_access_get(access_id);
1127	if (IS_ERR(staccess))
1128		return PTR_ERR(staccess);
1129
1130	rc = iommufd_access_replace(staccess->access, ioas_id);
1131	fput(staccess->file);
1132	return rc;
1133}
1134
1135/* Check that the pages in a page array match the pages in the user VA */
1136static int iommufd_test_check_pages(void __user *uptr, struct page **pages,
1137				    size_t npages)
1138{
1139	for (; npages; npages--) {
1140		struct page *tmp_pages[1];
1141		long rc;
1142
1143		rc = get_user_pages_fast((uintptr_t)uptr, 1, 0, tmp_pages);
1144		if (rc < 0)
1145			return rc;
1146		if (WARN_ON(rc != 1))
1147			return -EFAULT;
1148		put_page(tmp_pages[0]);
1149		if (tmp_pages[0] != *pages)
1150			return -EBADE;
1151		pages++;
1152		uptr += PAGE_SIZE;
1153	}
1154	return 0;
1155}
1156
/*
 * Pin [iova, iova + length) through the test access @access_id and record
 * the pin as a selftest_access_item on staccess->items. The new item's id is
 * reported in cmd->access_pages.out_access_pages_id. The pin presumably stays
 * held until the item is destroyed via IOMMU_TEST_OP_DESTROY_ACCESS_PAGES
 * (confirm in iommufd_test_access_item_destroy()).
 *
 * Only accesses created with selftest_access_ops_pin are supported.
 * MOCK_FLAGS_ACCESS_WRITE requests a writable pin. With MOCK_FLAGS_ACCESS_SYZ
 * the iova is recomputed from cmd->access_pages.iova by
 * iommufd_test_syz_conv_iova().
 *
 * If @uptr is non-NULL, the pinned pages are cross-checked against the pages
 * backing that user VA.
 */
static int iommufd_test_access_pages(struct iommufd_ucmd *ucmd,
				     unsigned int access_id, unsigned long iova,
				     size_t length, void __user *uptr,
				     u32 flags)
{
	struct iommu_test_cmd *cmd = ucmd->cmd;
	struct selftest_access_item *item;
	struct selftest_access *staccess;
	struct page **pages;
	size_t npages;
	int rc;

	/* Prevent syzkaller from triggering a WARN_ON in kvzalloc() */
	if (length > 16*1024*1024)
		return -ENOMEM;

	if (flags & ~(MOCK_FLAGS_ACCESS_WRITE | MOCK_FLAGS_ACCESS_SYZ))
		return -EOPNOTSUPP;

	staccess = iommufd_access_get(access_id);
	if (IS_ERR(staccess))
		return PTR_ERR(staccess);

	if (staccess->access->ops != &selftest_access_ops_pin) {
		rc = -EOPNOTSUPP;
		goto out_put;
	}

	if (flags & MOCK_FLAGS_ACCESS_SYZ)
		iova = iommufd_test_syz_conv_iova(staccess->access,
					&cmd->access_pages.iova);

	/* Number of CPU pages touched by the (possibly unaligned) range */
	npages = (ALIGN(iova + length, PAGE_SIZE) -
		  ALIGN_DOWN(iova, PAGE_SIZE)) /
		 PAGE_SIZE;
	pages = kvcalloc(npages, sizeof(*pages), GFP_KERNEL_ACCOUNT);
	if (!pages) {
		rc = -ENOMEM;
		goto out_put;
	}

	/*
	 * Drivers will need to think very carefully about this locking. The
	 * core code can do multiple unmaps instantaneously after
	 * iommufd_access_pin_pages() and *all* the unmaps must not return until
	 * the range is unpinned. This simple implementation puts a global lock
	 * around the pin, which may not suit drivers that want this to be a
	 * performance path. drivers that get this wrong will trigger WARN_ON
	 * races and cause EDEADLOCK failures to userspace.
	 */
	mutex_lock(&staccess->lock);
	rc = iommufd_access_pin_pages(staccess->access, iova, length, pages,
				      flags & MOCK_FLAGS_ACCESS_WRITE);
	if (rc)
		goto out_unlock;

	/* For syzkaller allow uptr to be NULL to skip this check */
	if (uptr) {
		/* Rewind uptr to the page boundary matching pages[0] */
		rc = iommufd_test_check_pages(
			uptr - (iova - ALIGN_DOWN(iova, PAGE_SIZE)), pages,
			npages);
		if (rc)
			goto out_unaccess;
	}

	item = kzalloc(sizeof(*item), GFP_KERNEL_ACCOUNT);
	if (!item) {
		rc = -ENOMEM;
		goto out_unaccess;
	}

	item->iova = iova;
	item->length = length;
	item->id = staccess->next_id++;
	list_add_tail(&item->items_elm, &staccess->items);

	cmd->access_pages.out_access_pages_id = item->id;
	rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
	if (rc)
		goto out_free_item;
	/* Success: keep the pin and the item; only release the temporaries */
	goto out_unlock;

out_free_item:
	list_del(&item->items_elm);
	kfree(item);
out_unaccess:
	iommufd_access_unpin_pages(staccess->access, iova, length);
out_unlock:
	mutex_unlock(&staccess->lock);
	/* The page array is scratch space for the check; the pin outlives it */
	kvfree(pages);
out_put:
	fput(staccess->file);
	return rc;
}
1251
1252static int iommufd_test_access_rw(struct iommufd_ucmd *ucmd,
1253				  unsigned int access_id, unsigned long iova,
1254				  size_t length, void __user *ubuf,
1255				  unsigned int flags)
1256{
1257	struct iommu_test_cmd *cmd = ucmd->cmd;
1258	struct selftest_access *staccess;
1259	void *tmp;
1260	int rc;
1261
1262	/* Prevent syzkaller from triggering a WARN_ON in kvzalloc() */
1263	if (length > 16*1024*1024)
1264		return -ENOMEM;
1265
1266	if (flags & ~(MOCK_ACCESS_RW_WRITE | MOCK_ACCESS_RW_SLOW_PATH |
1267		      MOCK_FLAGS_ACCESS_SYZ))
1268		return -EOPNOTSUPP;
1269
1270	staccess = iommufd_access_get(access_id);
1271	if (IS_ERR(staccess))
1272		return PTR_ERR(staccess);
1273
1274	tmp = kvzalloc(length, GFP_KERNEL_ACCOUNT);
1275	if (!tmp) {
1276		rc = -ENOMEM;
1277		goto out_put;
1278	}
1279
1280	if (flags & MOCK_ACCESS_RW_WRITE) {
1281		if (copy_from_user(tmp, ubuf, length)) {
1282			rc = -EFAULT;
1283			goto out_free;
1284		}
1285	}
1286
1287	if (flags & MOCK_FLAGS_ACCESS_SYZ)
1288		iova = iommufd_test_syz_conv_iova(staccess->access,
1289				&cmd->access_rw.iova);
1290
1291	rc = iommufd_access_rw(staccess->access, iova, tmp, length, flags);
1292	if (rc)
1293		goto out_free;
1294	if (!(flags & MOCK_ACCESS_RW_WRITE)) {
1295		if (copy_to_user(ubuf, tmp, length)) {
1296			rc = -EFAULT;
1297			goto out_free;
1298		}
1299	}
1300
1301out_free:
1302	kvfree(tmp);
1303out_put:
1304	fput(staccess->file);
1305	return rc;
1306}
/*
 * iommufd_test_access_rw() passes the user-supplied flags straight through to
 * iommufd_access_rw(), so the MOCK_ACCESS_RW_* test uAPI values must stay
 * identical to the in-kernel IOMMUFD_ACCESS_RW_* flags.
 */
static_assert((unsigned int)MOCK_ACCESS_RW_WRITE == IOMMUFD_ACCESS_RW_WRITE);
static_assert((unsigned int)MOCK_ACCESS_RW_SLOW_PATH ==
	      __IOMMUFD_ACCESS_RW_SLOW_PATH);
1310
1311static int iommufd_test_dirty(struct iommufd_ucmd *ucmd, unsigned int mockpt_id,
1312			      unsigned long iova, size_t length,
1313			      unsigned long page_size, void __user *uptr,
1314			      u32 flags)
1315{
1316	unsigned long bitmap_size, i, max;
1317	struct iommu_test_cmd *cmd = ucmd->cmd;
1318	struct iommufd_hw_pagetable *hwpt;
1319	struct mock_iommu_domain *mock;
1320	int rc, count = 0;
1321	void *tmp;
1322
1323	if (!page_size || !length || iova % page_size || length % page_size ||
1324	    !uptr)
1325		return -EINVAL;
1326
1327	hwpt = get_md_pagetable(ucmd, mockpt_id, &mock);
1328	if (IS_ERR(hwpt))
1329		return PTR_ERR(hwpt);
1330
1331	if (!(mock->flags & MOCK_DIRTY_TRACK)) {
1332		rc = -EINVAL;
1333		goto out_put;
1334	}
1335
1336	max = length / page_size;
1337	bitmap_size = max / BITS_PER_BYTE;
1338
1339	tmp = kvzalloc(bitmap_size, GFP_KERNEL_ACCOUNT);
1340	if (!tmp) {
1341		rc = -ENOMEM;
1342		goto out_put;
1343	}
1344
1345	if (copy_from_user(tmp, uptr, bitmap_size)) {
1346		rc = -EFAULT;
1347		goto out_free;
1348	}
1349
1350	for (i = 0; i < max; i++) {
1351		unsigned long cur = iova + i * page_size;
1352		void *ent, *old;
1353
1354		if (!test_bit(i, (unsigned long *)tmp))
1355			continue;
1356
1357		ent = xa_load(&mock->pfns, cur / page_size);
1358		if (ent) {
1359			unsigned long val;
1360
1361			val = xa_to_value(ent) | MOCK_PFN_DIRTY_IOVA;
1362			old = xa_store(&mock->pfns, cur / page_size,
1363				       xa_mk_value(val), GFP_KERNEL);
1364			WARN_ON_ONCE(ent != old);
1365			count++;
1366		}
1367	}
1368
1369	cmd->dirty.out_nr_dirty = count;
1370	rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
1371out_free:
1372	kvfree(tmp);
1373out_put:
1374	iommufd_put_object(ucmd->ictx, &hwpt->obj);
1375	return rc;
1376}
1377
1378void iommufd_selftest_destroy(struct iommufd_object *obj)
1379{
1380	struct selftest_obj *sobj = container_of(obj, struct selftest_obj, obj);
1381
1382	switch (sobj->type) {
1383	case TYPE_IDEV:
1384		iommufd_device_detach(sobj->idev.idev);
1385		iommufd_device_unbind(sobj->idev.idev);
1386		mock_dev_destroy(sobj->idev.mock_dev);
1387		break;
1388	}
1389}
1390
1391int iommufd_test(struct iommufd_ucmd *ucmd)
1392{
1393	struct iommu_test_cmd *cmd = ucmd->cmd;
1394
1395	switch (cmd->op) {
1396	case IOMMU_TEST_OP_ADD_RESERVED:
1397		return iommufd_test_add_reserved(ucmd, cmd->id,
1398						 cmd->add_reserved.start,
1399						 cmd->add_reserved.length);
1400	case IOMMU_TEST_OP_MOCK_DOMAIN:
1401	case IOMMU_TEST_OP_MOCK_DOMAIN_FLAGS:
1402		return iommufd_test_mock_domain(ucmd, cmd);
1403	case IOMMU_TEST_OP_MOCK_DOMAIN_REPLACE:
1404		return iommufd_test_mock_domain_replace(
1405			ucmd, cmd->id, cmd->mock_domain_replace.pt_id, cmd);
1406	case IOMMU_TEST_OP_MD_CHECK_MAP:
1407		return iommufd_test_md_check_pa(
1408			ucmd, cmd->id, cmd->check_map.iova,
1409			cmd->check_map.length,
1410			u64_to_user_ptr(cmd->check_map.uptr));
1411	case IOMMU_TEST_OP_MD_CHECK_REFS:
1412		return iommufd_test_md_check_refs(
1413			ucmd, u64_to_user_ptr(cmd->check_refs.uptr),
1414			cmd->check_refs.length, cmd->check_refs.refs);
1415	case IOMMU_TEST_OP_MD_CHECK_IOTLB:
1416		return iommufd_test_md_check_iotlb(ucmd, cmd->id,
1417						   cmd->check_iotlb.id,
1418						   cmd->check_iotlb.iotlb);
1419	case IOMMU_TEST_OP_CREATE_ACCESS:
1420		return iommufd_test_create_access(ucmd, cmd->id,
1421						  cmd->create_access.flags);
1422	case IOMMU_TEST_OP_ACCESS_REPLACE_IOAS:
1423		return iommufd_test_access_replace_ioas(
1424			ucmd, cmd->id, cmd->access_replace_ioas.ioas_id);
1425	case IOMMU_TEST_OP_ACCESS_PAGES:
1426		return iommufd_test_access_pages(
1427			ucmd, cmd->id, cmd->access_pages.iova,
1428			cmd->access_pages.length,
1429			u64_to_user_ptr(cmd->access_pages.uptr),
1430			cmd->access_pages.flags);
1431	case IOMMU_TEST_OP_ACCESS_RW:
1432		return iommufd_test_access_rw(
1433			ucmd, cmd->id, cmd->access_rw.iova,
1434			cmd->access_rw.length,
1435			u64_to_user_ptr(cmd->access_rw.uptr),
1436			cmd->access_rw.flags);
1437	case IOMMU_TEST_OP_DESTROY_ACCESS_PAGES:
1438		return iommufd_test_access_item_destroy(
1439			ucmd, cmd->id, cmd->destroy_access_pages.access_pages_id);
1440	case IOMMU_TEST_OP_SET_TEMP_MEMORY_LIMIT:
1441		/* Protect _batch_init(), can not be less than elmsz */
1442		if (cmd->memory_limit.limit <
1443		    sizeof(unsigned long) + sizeof(u32))
1444			return -EINVAL;
1445		iommufd_test_memory_limit = cmd->memory_limit.limit;
1446		return 0;
1447	case IOMMU_TEST_OP_DIRTY:
1448		return iommufd_test_dirty(ucmd, cmd->id, cmd->dirty.iova,
1449					  cmd->dirty.length,
1450					  cmd->dirty.page_size,
1451					  u64_to_user_ptr(cmd->dirty.uptr),
1452					  cmd->dirty.flags);
1453	default:
1454		return -EOPNOTSUPP;
1455	}
1456}
1457
1458bool iommufd_should_fail(void)
1459{
1460	return should_fail(&fail_iommufd, 1);
1461}
1462
1463int __init iommufd_test_init(void)
1464{
1465	struct platform_device_info pdevinfo = {
1466		.name = "iommufd_selftest_iommu",
1467	};
1468	int rc;
1469
1470	dbgfs_root =
1471		fault_create_debugfs_attr("fail_iommufd", NULL, &fail_iommufd);
1472
1473	selftest_iommu_dev = platform_device_register_full(&pdevinfo);
1474	if (IS_ERR(selftest_iommu_dev)) {
1475		rc = PTR_ERR(selftest_iommu_dev);
1476		goto err_dbgfs;
1477	}
1478
1479	rc = bus_register(&iommufd_mock_bus_type.bus);
1480	if (rc)
1481		goto err_platform;
1482
1483	rc = iommu_device_sysfs_add(&mock_iommu_device,
1484				    &selftest_iommu_dev->dev, NULL, "%s",
1485				    dev_name(&selftest_iommu_dev->dev));
1486	if (rc)
1487		goto err_bus;
1488
1489	rc = iommu_device_register_bus(&mock_iommu_device, &mock_ops,
1490				  &iommufd_mock_bus_type.bus,
1491				  &iommufd_mock_bus_type.nb);
1492	if (rc)
1493		goto err_sysfs;
1494	return 0;
1495
1496err_sysfs:
1497	iommu_device_sysfs_remove(&mock_iommu_device);
1498err_bus:
1499	bus_unregister(&iommufd_mock_bus_type.bus);
1500err_platform:
1501	platform_device_unregister(selftest_iommu_dev);
1502err_dbgfs:
1503	debugfs_remove_recursive(dbgfs_root);
1504	return rc;
1505}
1506
1507void iommufd_test_exit(void)
1508{
1509	iommu_device_sysfs_remove(&mock_iommu_device);
1510	iommu_device_unregister_bus(&mock_iommu_device,
1511				    &iommufd_mock_bus_type.bus,
1512				    &iommufd_mock_bus_type.nb);
1513	bus_unregister(&iommufd_mock_bus_type.bus);
1514	platform_device_unregister(selftest_iommu_dev);
1515	debugfs_remove_recursive(dbgfs_root);
1516}
1517