1// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
2/*
3 * Copyright(c) 2016 Intel Corporation.
4 */
5
6#include <linux/slab.h>
7#include <linux/sched.h>
8#include <linux/rculist.h>
9#include <rdma/rdma_vt.h>
10#include <rdma/rdmavt_qp.h>
11
12#include "mcast.h"
13
/**
 * rvt_driver_mcast_init - init resources for multicast
 * @rdi: rvt dev struct
 *
 * This is called per device that registers with rdmavt.  It currently only
 * initializes @rdi->n_mcast_grps_lock.  That lock serializes updates to the
 * device-wide n_mcast_grps_allocated counter when groups are created and
 * destroyed.
 */
void rvt_driver_mcast_init(struct rvt_dev_info *rdi)
{
	/*
	 * Anything that needs setup for multicast on a per driver or per rdi
	 * basis should be done in here.
	 */
	spin_lock_init(&rdi->n_mcast_grps_lock);
}
28
29/**
30 * rvt_mcast_qp_alloc - alloc a struct to link a QP to mcast GID struct
31 * @qp: the QP to link
32 */
33static struct rvt_mcast_qp *rvt_mcast_qp_alloc(struct rvt_qp *qp)
34{
35	struct rvt_mcast_qp *mqp;
36
37	mqp = kmalloc(sizeof(*mqp), GFP_KERNEL);
38	if (!mqp)
39		goto bail;
40
41	mqp->qp = qp;
42	rvt_get_qp(qp);
43
44bail:
45	return mqp;
46}
47
48static void rvt_mcast_qp_free(struct rvt_mcast_qp *mqp)
49{
50	struct rvt_qp *qp = mqp->qp;
51
52	/* Notify hfi1_destroy_qp() if it is waiting. */
53	rvt_put_qp(qp);
54
55	kfree(mqp);
56}
57
58/**
59 * rvt_mcast_alloc - allocate the multicast GID structure
60 * @mgid: the multicast GID
61 * @lid: the muilticast LID (host order)
62 *
63 * A list of QPs will be attached to this structure.
64 */
65static struct rvt_mcast *rvt_mcast_alloc(union ib_gid *mgid, u16 lid)
66{
67	struct rvt_mcast *mcast;
68
69	mcast = kzalloc(sizeof(*mcast), GFP_KERNEL);
70	if (!mcast)
71		goto bail;
72
73	mcast->mcast_addr.mgid = *mgid;
74	mcast->mcast_addr.lid = lid;
75
76	INIT_LIST_HEAD(&mcast->qp_list);
77	init_waitqueue_head(&mcast->wait);
78	atomic_set(&mcast->refcount, 0);
79
80bail:
81	return mcast;
82}
83
84static void rvt_mcast_free(struct rvt_mcast *mcast)
85{
86	struct rvt_mcast_qp *p, *tmp;
87
88	list_for_each_entry_safe(p, tmp, &mcast->qp_list, list)
89		rvt_mcast_qp_free(p);
90
91	kfree(mcast);
92}
93
94/**
95 * rvt_mcast_find - search the global table for the given multicast GID/LID
96 * NOTE: It is valid to have 1 MLID with multiple MGIDs.  It is not valid
97 * to have 1 MGID with multiple MLIDs.
98 * @ibp: the IB port structure
99 * @mgid: the multicast GID to search for
100 * @lid: the multicast LID portion of the multicast address (host order)
101 *
102 * The caller is responsible for decrementing the reference count if found.
103 *
104 * Return: NULL if not found.
105 */
106struct rvt_mcast *rvt_mcast_find(struct rvt_ibport *ibp, union ib_gid *mgid,
107				 u16 lid)
108{
109	struct rb_node *n;
110	unsigned long flags;
111	struct rvt_mcast *found = NULL;
112
113	spin_lock_irqsave(&ibp->lock, flags);
114	n = ibp->mcast_tree.rb_node;
115	while (n) {
116		int ret;
117		struct rvt_mcast *mcast;
118
119		mcast = rb_entry(n, struct rvt_mcast, rb_node);
120
121		ret = memcmp(mgid->raw, mcast->mcast_addr.mgid.raw,
122			     sizeof(*mgid));
123		if (ret < 0) {
124			n = n->rb_left;
125		} else if (ret > 0) {
126			n = n->rb_right;
127		} else {
128			/* MGID/MLID must match */
129			if (mcast->mcast_addr.lid == lid) {
130				atomic_inc(&mcast->refcount);
131				found = mcast;
132			}
133			break;
134		}
135	}
136	spin_unlock_irqrestore(&ibp->lock, flags);
137	return found;
138}
139EXPORT_SYMBOL(rvt_mcast_find);
140
141/*
142 * rvt_mcast_add - insert mcast GID into table and attach QP struct
143 * @mcast: the mcast GID table
144 * @mqp: the QP to attach
145 *
146 * Return: zero if both were added.  Return EEXIST if the GID was already in
147 * the table but the QP was added.  Return ESRCH if the QP was already
148 * attached and neither structure was added. Return EINVAL if the MGID was
149 * found, but the MLID did NOT match.
150 */
151static int rvt_mcast_add(struct rvt_dev_info *rdi, struct rvt_ibport *ibp,
152			 struct rvt_mcast *mcast, struct rvt_mcast_qp *mqp)
153{
154	struct rb_node **n = &ibp->mcast_tree.rb_node;
155	struct rb_node *pn = NULL;
156	int ret;
157
158	spin_lock_irq(&ibp->lock);
159
160	while (*n) {
161		struct rvt_mcast *tmcast;
162		struct rvt_mcast_qp *p;
163
164		pn = *n;
165		tmcast = rb_entry(pn, struct rvt_mcast, rb_node);
166
167		ret = memcmp(mcast->mcast_addr.mgid.raw,
168			     tmcast->mcast_addr.mgid.raw,
169			     sizeof(mcast->mcast_addr.mgid));
170		if (ret < 0) {
171			n = &pn->rb_left;
172			continue;
173		}
174		if (ret > 0) {
175			n = &pn->rb_right;
176			continue;
177		}
178
179		if (tmcast->mcast_addr.lid != mcast->mcast_addr.lid) {
180			ret = EINVAL;
181			goto bail;
182		}
183
184		/* Search the QP list to see if this is already there. */
185		list_for_each_entry_rcu(p, &tmcast->qp_list, list) {
186			if (p->qp == mqp->qp) {
187				ret = ESRCH;
188				goto bail;
189			}
190		}
191		if (tmcast->n_attached ==
192		    rdi->dparms.props.max_mcast_qp_attach) {
193			ret = ENOMEM;
194			goto bail;
195		}
196
197		tmcast->n_attached++;
198
199		list_add_tail_rcu(&mqp->list, &tmcast->qp_list);
200		ret = EEXIST;
201		goto bail;
202	}
203
204	spin_lock(&rdi->n_mcast_grps_lock);
205	if (rdi->n_mcast_grps_allocated == rdi->dparms.props.max_mcast_grp) {
206		spin_unlock(&rdi->n_mcast_grps_lock);
207		ret = ENOMEM;
208		goto bail;
209	}
210
211	rdi->n_mcast_grps_allocated++;
212	spin_unlock(&rdi->n_mcast_grps_lock);
213
214	mcast->n_attached++;
215
216	list_add_tail_rcu(&mqp->list, &mcast->qp_list);
217
218	atomic_inc(&mcast->refcount);
219	rb_link_node(&mcast->rb_node, pn, n);
220	rb_insert_color(&mcast->rb_node, &ibp->mcast_tree);
221
222	ret = 0;
223
224bail:
225	spin_unlock_irq(&ibp->lock);
226
227	return ret;
228}
229
230/**
231 * rvt_attach_mcast - attach a qp to a multicast group
232 * @ibqp: Infiniband qp
233 * @gid: multicast guid
234 * @lid: multicast lid
235 *
236 * Return: 0 on success
237 */
238int rvt_attach_mcast(struct ib_qp *ibqp, union ib_gid *gid, u16 lid)
239{
240	struct rvt_qp *qp = ibqp_to_rvtqp(ibqp);
241	struct rvt_dev_info *rdi = ib_to_rvt(ibqp->device);
242	struct rvt_ibport *ibp = rdi->ports[qp->port_num - 1];
243	struct rvt_mcast *mcast;
244	struct rvt_mcast_qp *mqp;
245	int ret = -ENOMEM;
246
247	if (ibqp->qp_num <= 1 || qp->state == IB_QPS_RESET)
248		return -EINVAL;
249
250	/*
251	 * Allocate data structures since its better to do this outside of
252	 * spin locks and it will most likely be needed.
253	 */
254	mcast = rvt_mcast_alloc(gid, lid);
255	if (!mcast)
256		return -ENOMEM;
257
258	mqp = rvt_mcast_qp_alloc(qp);
259	if (!mqp)
260		goto bail_mcast;
261
262	switch (rvt_mcast_add(rdi, ibp, mcast, mqp)) {
263	case ESRCH:
264		/* Neither was used: OK to attach the same QP twice. */
265		ret = 0;
266		goto bail_mqp;
267	case EEXIST: /* The mcast wasn't used */
268		ret = 0;
269		goto bail_mcast;
270	case ENOMEM:
271		/* Exceeded the maximum number of mcast groups. */
272		ret = -ENOMEM;
273		goto bail_mqp;
274	case EINVAL:
275		/* Invalid MGID/MLID pair */
276		ret = -EINVAL;
277		goto bail_mqp;
278	default:
279		break;
280	}
281
282	return 0;
283
284bail_mqp:
285	rvt_mcast_qp_free(mqp);
286
287bail_mcast:
288	rvt_mcast_free(mcast);
289
290	return ret;
291}
292
/**
 * rvt_detach_mcast - remove a qp from a multicast group
 * @ibqp: Infiniband qp
 * @gid: multicast guid
 * @lid: multicast lid
 *
 * Under ibp->lock, this looks up the group by MGID in the port's tree and
 * unlinks @ibqp's entry from the group's QP list.  If that was the last QP,
 * the group is also erased from the tree.
 *
 * The lock is then dropped.  The unlinked entry, and for the last QP the
 * group itself, is freed only after waiting on mcast->wait for the group's
 * reference count to fall.  This keeps holders of a rvt_mcast_find()
 * reference from being left with freed list entries.
 *
 * Return: 0 on success.  -EINVAL for QP0/QP1, when no group has @gid, when
 * the group's MLID differs from @lid, or when @ibqp is not attached.
 */
int rvt_detach_mcast(struct ib_qp *ibqp, union ib_gid *gid, u16 lid)
{
	struct rvt_qp *qp = ibqp_to_rvtqp(ibqp);
	struct rvt_dev_info *rdi = ib_to_rvt(ibqp->device);
	struct rvt_ibport *ibp = rdi->ports[qp->port_num - 1];
	struct rvt_mcast *mcast = NULL;
	struct rvt_mcast_qp *p, *tmp, *delp = NULL;
	struct rb_node *n;
	int last = 0;
	int ret = 0;

	if (ibqp->qp_num <= 1)
		return -EINVAL;

	spin_lock_irq(&ibp->lock);

	/* Find the GID in the mcast table. */
	n = ibp->mcast_tree.rb_node;
	while (1) {
		if (!n) {
			spin_unlock_irq(&ibp->lock);
			return -EINVAL;
		}

		mcast = rb_entry(n, struct rvt_mcast, rb_node);
		ret = memcmp(gid->raw, mcast->mcast_addr.mgid.raw,
			     sizeof(*gid));
		if (ret < 0) {
			n = n->rb_left;
		} else if (ret > 0) {
			n = n->rb_right;
		} else {
			/* MGID/MLID must match */
			if (mcast->mcast_addr.lid != lid) {
				spin_unlock_irq(&ibp->lock);
				return -EINVAL;
			}
			break;
		}
	}

	/* Search the QP list. */
	list_for_each_entry_safe(p, tmp, &mcast->qp_list, list) {
		if (p->qp != qp)
			continue;
		/*
		 * We found it, so remove it, but don't poison the forward
		 * link until we are sure there are no list walkers.
		 */
		list_del_rcu(&p->list);
		mcast->n_attached--;
		delp = p;

		/* If this was the last attached QP, remove the GID too. */
		if (list_empty(&mcast->qp_list)) {
			rb_erase(&mcast->rb_node, &ibp->mcast_tree);
			last = 1;
		}
		break;
	}

	spin_unlock_irq(&ibp->lock);
	/* QP not attached */
	if (!delp)
		return -EINVAL;

	/*
	 * Wait for any list walkers to finish before freeing the
	 * list element.
	 *
	 * A count of 1 is the reference rvt_mcast_add() took when it inserted
	 * the group into the tree.  Anything above that is an outstanding
	 * rvt_mcast_find() user.
	 *
	 * NOTE(review): when !last, mcast is dereferenced here after ibp->lock
	 * was dropped.  A concurrent detach of the group's final QP (a
	 * different QP) could erase and free mcast in that window.  This looks
	 * like a possible use-after-free; confirm whether callers serialize
	 * detaches for the same group.
	 */
	wait_event(mcast->wait, atomic_read(&mcast->refcount) <= 1);
	rvt_mcast_qp_free(delp);

	if (last) {
		/*
		 * Drop the tree's reference, wait for every remaining user, then
		 * free the group and give its slot back to the device-wide count.
		 */
		atomic_dec(&mcast->refcount);
		wait_event(mcast->wait, !atomic_read(&mcast->refcount));
		rvt_mcast_free(mcast);
		spin_lock_irq(&rdi->n_mcast_grps_lock);
		rdi->n_mcast_grps_allocated--;
		spin_unlock_irq(&rdi->n_mcast_grps_lock);
	}

	return 0;
}
385
386/**
387 * rvt_mcast_tree_empty - determine if any qps are attached to any mcast group
388 * @rdi: rvt dev struct
389 *
390 * Return: in use count
391 */
392int rvt_mcast_tree_empty(struct rvt_dev_info *rdi)
393{
394	int i;
395	int in_use = 0;
396
397	for (i = 0; i < rdi->dparms.nports; i++)
398		if (rdi->ports[i]->mcast_tree.rb_node)
399			in_use++;
400	return in_use;
401}
402