1// SPDX-License-Identifier: GPL-2.0-only
2/*
3 * Copyright (C) 2024 Advanced Micro Devices, Inc.
4 */
5
6#define pr_fmt(fmt)     "AMD-Vi: " fmt
7#define dev_fmt(fmt)    pr_fmt(fmt)
8
9#include <linux/iommu.h>
10#include <linux/mm_types.h>
11
12#include "amd_iommu.h"
13
14static inline bool is_pasid_enabled(struct iommu_dev_data *dev_data)
15{
16	if (dev_data->pasid_enabled && dev_data->max_pasids &&
17	    dev_data->gcr3_info.gcr3_tbl != NULL)
18		return true;
19
20	return false;
21}
22
23static inline bool is_pasid_valid(struct iommu_dev_data *dev_data,
24				  ioasid_t pasid)
25{
26	if (pasid > 0 && pasid < dev_data->max_pasids)
27		return true;
28
29	return false;
30}
31
32static void remove_dev_pasid(struct pdom_dev_data *pdom_dev_data)
33{
34	/* Update GCR3 table and flush IOTLB */
35	amd_iommu_clear_gcr3(pdom_dev_data->dev_data, pdom_dev_data->pasid);
36
37	list_del(&pdom_dev_data->list);
38	kfree(pdom_dev_data);
39}
40
41/* Clear PASID from device GCR3 table and remove pdom_dev_data from list */
42static void remove_pdom_dev_pasid(struct protection_domain *pdom,
43				  struct device *dev, ioasid_t pasid)
44{
45	struct pdom_dev_data *pdom_dev_data;
46	struct iommu_dev_data *dev_data = dev_iommu_priv_get(dev);
47
48	lockdep_assert_held(&pdom->lock);
49
50	for_each_pdom_dev_data(pdom_dev_data, pdom) {
51		if (pdom_dev_data->dev_data == dev_data &&
52		    pdom_dev_data->pasid == pasid) {
53			remove_dev_pasid(pdom_dev_data);
54			break;
55		}
56	}
57}
58
59static void sva_arch_invalidate_secondary_tlbs(struct mmu_notifier *mn,
60				    struct mm_struct *mm,
61				    unsigned long start, unsigned long end)
62{
63	struct pdom_dev_data *pdom_dev_data;
64	struct protection_domain *sva_pdom;
65	unsigned long flags;
66
67	sva_pdom = container_of(mn, struct protection_domain, mn);
68
69	spin_lock_irqsave(&sva_pdom->lock, flags);
70
71	for_each_pdom_dev_data(pdom_dev_data, sva_pdom) {
72		amd_iommu_dev_flush_pasid_pages(pdom_dev_data->dev_data,
73						pdom_dev_data->pasid,
74						start, end - start);
75	}
76
77	spin_unlock_irqrestore(&sva_pdom->lock, flags);
78}
79
80static void sva_mn_release(struct mmu_notifier *mn, struct mm_struct *mm)
81{
82	struct pdom_dev_data *pdom_dev_data, *next;
83	struct protection_domain *sva_pdom;
84	unsigned long flags;
85
86	sva_pdom = container_of(mn, struct protection_domain, mn);
87
88	spin_lock_irqsave(&sva_pdom->lock, flags);
89
90	/* Assume dev_data_list contains same PASID with different devices */
91	for_each_pdom_dev_data_safe(pdom_dev_data, next, sva_pdom)
92		remove_dev_pasid(pdom_dev_data);
93
94	spin_unlock_irqrestore(&sva_pdom->lock, flags);
95}
96
97static const struct mmu_notifier_ops sva_mn = {
98	.arch_invalidate_secondary_tlbs = sva_arch_invalidate_secondary_tlbs,
99	.release = sva_mn_release,
100};
101
102int iommu_sva_set_dev_pasid(struct iommu_domain *domain,
103			    struct device *dev, ioasid_t pasid)
104{
105	struct pdom_dev_data *pdom_dev_data;
106	struct protection_domain *sva_pdom = to_pdomain(domain);
107	struct iommu_dev_data *dev_data = dev_iommu_priv_get(dev);
108	unsigned long flags;
109	int ret = -EINVAL;
110
111	/* PASID zero is used for requests from the I/O device without PASID */
112	if (!is_pasid_valid(dev_data, pasid))
113		return ret;
114
115	/* Make sure PASID is enabled */
116	if (!is_pasid_enabled(dev_data))
117		return ret;
118
119	/* Add PASID to protection domain pasid list */
120	pdom_dev_data = kzalloc(sizeof(*pdom_dev_data), GFP_KERNEL);
121	if (pdom_dev_data == NULL)
122		return ret;
123
124	pdom_dev_data->pasid = pasid;
125	pdom_dev_data->dev_data = dev_data;
126
127	spin_lock_irqsave(&sva_pdom->lock, flags);
128
129	/* Setup GCR3 table */
130	ret = amd_iommu_set_gcr3(dev_data, pasid,
131				 iommu_virt_to_phys(domain->mm->pgd));
132	if (ret) {
133		kfree(pdom_dev_data);
134		goto out_unlock;
135	}
136
137	list_add(&pdom_dev_data->list, &sva_pdom->dev_data_list);
138
139out_unlock:
140	spin_unlock_irqrestore(&sva_pdom->lock, flags);
141	return ret;
142}
143
144void amd_iommu_remove_dev_pasid(struct device *dev, ioasid_t pasid,
145				struct iommu_domain *domain)
146{
147	struct protection_domain *sva_pdom;
148	unsigned long flags;
149
150	if (!is_pasid_valid(dev_iommu_priv_get(dev), pasid))
151		return;
152
153	sva_pdom = to_pdomain(domain);
154
155	spin_lock_irqsave(&sva_pdom->lock, flags);
156
157	/* Remove PASID from dev_data_list */
158	remove_pdom_dev_pasid(sva_pdom, dev, pasid);
159
160	spin_unlock_irqrestore(&sva_pdom->lock, flags);
161}
162
163static void iommu_sva_domain_free(struct iommu_domain *domain)
164{
165	struct protection_domain *sva_pdom = to_pdomain(domain);
166
167	if (sva_pdom->mn.ops)
168		mmu_notifier_unregister(&sva_pdom->mn, domain->mm);
169
170	amd_iommu_domain_free(domain);
171}
172
173static const struct iommu_domain_ops amd_sva_domain_ops = {
174	.set_dev_pasid = iommu_sva_set_dev_pasid,
175	.free	       = iommu_sva_domain_free
176};
177
178struct iommu_domain *amd_iommu_domain_alloc_sva(struct device *dev,
179						struct mm_struct *mm)
180{
181	struct protection_domain *pdom;
182	int ret;
183
184	pdom = protection_domain_alloc(IOMMU_DOMAIN_SVA);
185	if (!pdom)
186		return ERR_PTR(-ENOMEM);
187
188	pdom->domain.ops = &amd_sva_domain_ops;
189	pdom->mn.ops = &sva_mn;
190
191	ret = mmu_notifier_register(&pdom->mn, mm);
192	if (ret) {
193		protection_domain_free(pdom);
194		return ERR_PTR(ret);
195	}
196
197	return &pdom->domain;
198}
199