ib_multicast.c revision 331769
/*
 * Copyright (c) 2006 Intel Corporation.  All rights reserved.
 *
 * This software is available to you under a choice of one of two
 * licenses.  You may choose to be licensed under the terms of the GNU
 * General Public License (GPL) Version 2, available from the file
 * COPYING in the main directory of this source tree, or the
 * OpenIB.org BSD license below:
 *
 *     Redistribution and use in source and binary forms, with or
 *     without modification, are permitted provided that the following
 *     conditions are met:
 *
 *      - Redistributions of source code must retain the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer.
 *
 *      - Redistributions in binary form must reproduce the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer in the documentation and/or other materials
 *        provided with the distribution.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#define	LINUXKPI_PARAM_PREFIX ibcore_

#include <linux/completion.h>
#include <linux/dma-mapping.h>
#include <linux/err.h>
#include <linux/interrupt.h>
#include <linux/slab.h>
#include <linux/bitops.h>
#include <linux/random.h>
#include <linux/rbtree.h>

#include <rdma/ib_cache.h>
#include "sa.h"

static void mcast_add_one(struct ib_device *device);
static void mcast_remove_one(struct ib_device *device, void *client_data);

static struct ib_client mcast_client = {
	.name   = "ib_multicast",
	.add    = mcast_add_one,
	.remove = mcast_remove_one
};

static struct ib_sa_client	sa_client;
static struct workqueue_struct	*mcast_wq;
static union ib_gid mgid0;

struct mcast_device;

struct mcast_port {
	struct mcast_device	*dev;
	spinlock_t		lock;
	struct rb_root		table;
	atomic_t		refcount;
	struct completion	comp;
	u8			port_num;
};

struct mcast_device {
	struct ib_device	*device;
	struct ib_event_handler	event_handler;
	int			start_port;
	int			end_port;
	struct mcast_port	port[0];
};

enum mcast_state {
	MCAST_JOINING,
	MCAST_MEMBER,
	MCAST_ERROR,
};

enum mcast_group_state {
	MCAST_IDLE,
	MCAST_BUSY,
	MCAST_GROUP_ERROR,
	MCAST_PKEY_EVENT
};

enum {
	MCAST_INVALID_PKEY_INDEX = 0xFFFF
};

struct mcast_member;

struct mcast_group {
	struct ib_sa_mcmember_rec rec;
	struct rb_node		node;
	struct mcast_port	*port;
	spinlock_t		lock;
	struct work_struct	work;
	struct list_head	pending_list;
	struct list_head	active_list;
	struct mcast_member	*last_join;
	int			members[NUM_JOIN_MEMBERSHIP_TYPES];
	atomic_t		refcount;
	enum mcast_group_state	state;
	struct ib_sa_query	*query;
	u16			pkey_index;
	u8			leave_state;
	int			retries;
};

struct mcast_member {
	struct ib_sa_multicast	multicast;
	struct ib_sa_client	*client;
	struct mcast_group	*group;
	struct list_head	list;
	enum mcast_state	state;
	atomic_t		refcount;
	struct completion	comp;
};

static void join_handler(int status, struct ib_sa_mcmember_rec *rec,
			 void *context);
static void leave_handler(int status, struct ib_sa_mcmember_rec *rec,
			  void *context);

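/*
 * Look up a group by MGID in the port's red-black tree.
 * The caller must hold port->lock.
 */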
static struct mcast_group *mcast_find(struct mcast_port *port,
				      union ib_gid *mgid)
{
	struct rb_node *node = port->table.rb_node;
	struct mcast_group *group;
	int ret;

	while (node) {
		group = rb_entry(node, struct mcast_group, node);
		ret = memcmp(mgid->raw, group->rec.mgid.raw, sizeof *mgid);
		if (!ret)
			return group;

		if (ret < 0)
			node = node->rb_left;
		else
			node = node->rb_right;
	}
	return NULL;
}

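/*
 * Insert a group into the port's tree, keyed by MGID.  If a group with the
 * same MGID already exists it is returned instead, unless allow_duplicates
 * is set (used for joins with the zero MGID, where the SA assigns the real
 * MGID after the join completes).  The caller must hold port->lock.
 */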
static struct mcast_group *mcast_insert(struct mcast_port *port,
					struct mcast_group *group,
					int allow_duplicates)
{
	struct rb_node **link = &port->table.rb_node;
	struct rb_node *parent = NULL;
	struct mcast_group *cur_group;
	int ret;

	while (*link) {
		parent = *link;
		cur_group = rb_entry(parent, struct mcast_group, node);

		ret = memcmp(group->rec.mgid.raw, cur_group->rec.mgid.raw,
			     sizeof group->rec.mgid);
		if (ret < 0)
			link = &(*link)->rb_left;
		else if (ret > 0)
			link = &(*link)->rb_right;
		else if (allow_duplicates)
			link = &(*link)->rb_left;
		else
			return cur_group;
	}
	rb_link_node(&group->node, parent, link);
	rb_insert_color(&group->node, &port->table);
	return NULL;
}

static void deref_port(struct mcast_port *port)
{
	if (atomic_dec_and_test(&port->refcount))
		complete(&port->comp);
}

static void release_group(struct mcast_group *group)
{
	struct mcast_port *port = group->port;
	unsigned long flags;

	spin_lock_irqsave(&port->lock, flags);
	if (atomic_dec_and_test(&group->refcount)) {
		rb_erase(&group->node, &port->table);
		spin_unlock_irqrestore(&port->lock, flags);
		kfree(group);
		deref_port(port);
	} else
		spin_unlock_irqrestore(&port->lock, flags);
}

static void deref_member(struct mcast_member *member)
{
	if (atomic_dec_and_test(&member->refcount))
		complete(&member->comp);
}

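/*
 * Queue a member on the group's pending list and, if the group is idle,
 * take a group reference and schedule the group's work handler.  The work
 * handler releases that reference when the group goes idle again.
 */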
static void queue_join(struct mcast_member *member)
{
	struct mcast_group *group = member->group;
	unsigned long flags;

	spin_lock_irqsave(&group->lock, flags);
	list_add_tail(&member->list, &group->pending_list);
	if (group->state == MCAST_IDLE) {
		group->state = MCAST_BUSY;
		atomic_inc(&group->refcount);
		queue_work(mcast_wq, &group->work);
	}
	spin_unlock_irqrestore(&group->lock, flags);
}

/*
 * A multicast group has four types of members: full member, non member,
 * sendonly non member, and sendonly full member.
 * We need to keep track of the number of members of each type based on
 * their join state.  Adjust the number of members that belong to the
 * specified join states.
 */
static void adjust_membership(struct mcast_group *group, u8 join_state, int inc)
{
	int i;

	for (i = 0; i < NUM_JOIN_MEMBERSHIP_TYPES; i++, join_state >>= 1)
		if (join_state & 0x1)
			group->members[i] += inc;
}

/*
 * If a multicast group has zero members left for a particular join state, but
 * the group is still a member with the SA, we need to leave that join state.
 * Determine which join states we still belong to, but that do not have any
 * active members.
 */
static u8 get_leave_state(struct mcast_group *group)
{
	u8 leave_state = 0;
	int i;

	for (i = 0; i < NUM_JOIN_MEMBERSHIP_TYPES; i++)
		if (!group->members[i])
			leave_state |= (0x1 << i);

	return leave_state & group->rec.join_state;
}

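/*
 * Check the group's value (src_value) against the requested value
 * (dst_value) using the SA selector semantics (greater than, less than,
 * exactly equal).  Returns nonzero if the group does not satisfy the
 * request.
 */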
static int check_selector(ib_sa_comp_mask comp_mask,
			  ib_sa_comp_mask selector_mask,
			  ib_sa_comp_mask value_mask,
			  u8 selector, u8 src_value, u8 dst_value)
{
	int err;

	if (!(comp_mask & selector_mask) || !(comp_mask & value_mask))
		return 0;

	switch (selector) {
	case IB_SA_GT:
		err = (src_value <= dst_value);
		break;
	case IB_SA_LT:
		err = (src_value >= dst_value);
		break;
	case IB_SA_EQ:
		err = (src_value != dst_value);
		break;
	default:
		err = 0;
		break;
	}

	return err;
}

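/*
 * Check whether a join request (dst) is compatible with the existing group
 * record (src) for each field named in the request's component mask.
 */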
static int cmp_rec(struct ib_sa_mcmember_rec *src,
		   struct ib_sa_mcmember_rec *dst, ib_sa_comp_mask comp_mask)
{
	/* MGID must already match */

	if (comp_mask & IB_SA_MCMEMBER_REC_PORT_GID &&
	    memcmp(&src->port_gid, &dst->port_gid, sizeof src->port_gid))
		return -EINVAL;
	if (comp_mask & IB_SA_MCMEMBER_REC_QKEY && src->qkey != dst->qkey)
		return -EINVAL;
	if (comp_mask & IB_SA_MCMEMBER_REC_MLID && src->mlid != dst->mlid)
		return -EINVAL;
	if (check_selector(comp_mask, IB_SA_MCMEMBER_REC_MTU_SELECTOR,
			   IB_SA_MCMEMBER_REC_MTU, dst->mtu_selector,
			   src->mtu, dst->mtu))
		return -EINVAL;
	if (comp_mask & IB_SA_MCMEMBER_REC_TRAFFIC_CLASS &&
	    src->traffic_class != dst->traffic_class)
		return -EINVAL;
	if (comp_mask & IB_SA_MCMEMBER_REC_PKEY && src->pkey != dst->pkey)
		return -EINVAL;
	if (check_selector(comp_mask, IB_SA_MCMEMBER_REC_RATE_SELECTOR,
			   IB_SA_MCMEMBER_REC_RATE, dst->rate_selector,
			   src->rate, dst->rate))
		return -EINVAL;
	if (check_selector(comp_mask,
			   IB_SA_MCMEMBER_REC_PACKET_LIFE_TIME_SELECTOR,
			   IB_SA_MCMEMBER_REC_PACKET_LIFE_TIME,
			   dst->packet_life_time_selector,
			   src->packet_life_time, dst->packet_life_time))
		return -EINVAL;
	if (comp_mask & IB_SA_MCMEMBER_REC_SL && src->sl != dst->sl)
		return -EINVAL;
	if (comp_mask & IB_SA_MCMEMBER_REC_FLOW_LABEL &&
	    src->flow_label != dst->flow_label)
		return -EINVAL;
	if (comp_mask & IB_SA_MCMEMBER_REC_HOP_LIMIT &&
	    src->hop_limit != dst->hop_limit)
		return -EINVAL;
	if (comp_mask & IB_SA_MCMEMBER_REC_SCOPE && src->scope != dst->scope)
		return -EINVAL;

	/* join_state checked separately, proxy_join ignored */

	return 0;
}

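/*
 * Send an SA SET (join) request for the given member.  join_handler() is
 * invoked when the query completes.
 */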
static int send_join(struct mcast_group *group, struct mcast_member *member)
{
	struct mcast_port *port = group->port;
	int ret;

	group->last_join = member;
	ret = ib_sa_mcmember_rec_query(&sa_client, port->dev->device,
				       port->port_num, IB_MGMT_METHOD_SET,
				       &member->multicast.rec,
				       member->multicast.comp_mask,
				       3000, GFP_KERNEL, join_handler, group,
				       &group->query);
	return (ret > 0) ? 0 : ret;
}

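/*
 * Send an SA DELETE (leave) request to drop the given join states from the
 * group's membership.  leave_handler() is invoked when the query completes.
 */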
static int send_leave(struct mcast_group *group, u8 leave_state)
{
	struct mcast_port *port = group->port;
	struct ib_sa_mcmember_rec rec;
	int ret;

	rec = group->rec;
	rec.join_state = leave_state;
	group->leave_state = leave_state;

	ret = ib_sa_mcmember_rec_query(&sa_client, port->dev->device,
				       port->port_num, IB_SA_METHOD_DELETE, &rec,
				       IB_SA_MCMEMBER_REC_MGID     |
				       IB_SA_MCMEMBER_REC_PORT_GID |
				       IB_SA_MCMEMBER_REC_JOIN_STATE,
				       3000, GFP_KERNEL, leave_handler,
				       group, &group->query);
	return (ret > 0) ? 0 : ret;
}

static void join_group(struct mcast_group *group, struct mcast_member *member,
		       u8 join_state)
{
	member->state = MCAST_MEMBER;
	adjust_membership(group, join_state, 1);
	group->rec.join_state |= join_state;
	member->multicast.rec = group->rec;
	member->multicast.rec.join_state = join_state;
	list_move(&member->list, &group->active_list);
}

static int fail_join(struct mcast_group *group, struct mcast_member *member,
		     int status)
{
	spin_lock_irq(&group->lock);
	list_del_init(&member->list);
	spin_unlock_irq(&group->lock);
	return member->multicast.callback(status, &member->multicast);
}
static void process_group_error(struct mcast_group *group)
{
	struct mcast_member *member;
	int ret = 0;
	u16 pkey_index;

	if (group->state == MCAST_PKEY_EVENT)
		ret = ib_find_pkey(group->port->dev->device,
				   group->port->port_num,
				   be16_to_cpu(group->rec.pkey), &pkey_index);

	spin_lock_irq(&group->lock);
	if (group->state == MCAST_PKEY_EVENT && !ret &&
	    group->pkey_index == pkey_index)
		goto out;

	while (!list_empty(&group->active_list)) {
		member = list_entry(group->active_list.next,
				    struct mcast_member, list);
		atomic_inc(&member->refcount);
		list_del_init(&member->list);
		adjust_membership(group, member->multicast.rec.join_state, -1);
		member->state = MCAST_ERROR;
		spin_unlock_irq(&group->lock);

		ret = member->multicast.callback(-ENETRESET,
						 &member->multicast);
		deref_member(member);
		if (ret)
			ib_sa_free_multicast(&member->multicast);
		spin_lock_irq(&group->lock);
	}

	group->rec.join_state = 0;
out:
	group->state = MCAST_BUSY;
	spin_unlock_irq(&group->lock);
}

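/*
 * Per-group worker.  Processes queued join requests, handles group error
 * and P_Key events, and issues a leave when no members remain for a join
 * state.  All SA interaction for a group is serialized through this
 * handler.
 */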
static void mcast_work_handler(struct work_struct *work)
{
	struct mcast_group *group;
	struct mcast_member *member;
	struct ib_sa_multicast *multicast;
	int status, ret;
	u8 join_state;

	group = container_of(work, typeof(*group), work);
retest:
	spin_lock_irq(&group->lock);
	while (!list_empty(&group->pending_list) ||
	       (group->state != MCAST_BUSY)) {

		if (group->state != MCAST_BUSY) {
			spin_unlock_irq(&group->lock);
			process_group_error(group);
			goto retest;
		}

		member = list_entry(group->pending_list.next,
				    struct mcast_member, list);
		multicast = &member->multicast;
		join_state = multicast->rec.join_state;
		atomic_inc(&member->refcount);

		if (join_state == (group->rec.join_state & join_state)) {
			status = cmp_rec(&group->rec, &multicast->rec,
					 multicast->comp_mask);
			if (!status)
				join_group(group, member, join_state);
			else
				list_del_init(&member->list);
			spin_unlock_irq(&group->lock);
			ret = multicast->callback(status, multicast);
		} else {
			spin_unlock_irq(&group->lock);
			status = send_join(group, member);
			if (!status) {
				deref_member(member);
				return;
			}
			ret = fail_join(group, member, status);
		}

		deref_member(member);
		if (ret)
			ib_sa_free_multicast(&member->multicast);
		spin_lock_irq(&group->lock);
	}

	join_state = get_leave_state(group);
	if (join_state) {
		group->rec.join_state &= ~join_state;
		spin_unlock_irq(&group->lock);
		if (send_leave(group, join_state))
			goto retest;
	} else {
		group->state = MCAST_IDLE;
		spin_unlock_irq(&group->lock);
		release_group(group);
	}
}

/*
 * Fail a join request if it is still active - at the head of the pending queue.
 */
static void process_join_error(struct mcast_group *group, int status)
{
	struct mcast_member *member;
	int ret;

	spin_lock_irq(&group->lock);
	member = list_entry(group->pending_list.next,
			    struct mcast_member, list);
	if (group->last_join == member) {
		atomic_inc(&member->refcount);
		list_del_init(&member->list);
		spin_unlock_irq(&group->lock);
		ret = member->multicast.callback(status, &member->multicast);
		deref_member(member);
		if (ret)
			ib_sa_free_multicast(&member->multicast);
	} else
		spin_unlock_irq(&group->lock);
}

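/*
 * Completion handler for an SA join query.  On success, adopt the member
 * record returned by the SA (which may assign the MGID) and continue
 * processing the group.
 */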
static void join_handler(int status, struct ib_sa_mcmember_rec *rec,
			 void *context)
{
	struct mcast_group *group = context;
	u16 pkey_index = MCAST_INVALID_PKEY_INDEX;

	if (status)
		process_join_error(group, status);
	else {
		int mgids_changed, is_mgid0;
		ib_find_pkey(group->port->dev->device, group->port->port_num,
			     be16_to_cpu(rec->pkey), &pkey_index);

		spin_lock_irq(&group->port->lock);
		if (group->state == MCAST_BUSY &&
		    group->pkey_index == MCAST_INVALID_PKEY_INDEX)
			group->pkey_index = pkey_index;
		mgids_changed = memcmp(&rec->mgid, &group->rec.mgid,
				       sizeof(group->rec.mgid));
		group->rec = *rec;
		if (mgids_changed) {
			rb_erase(&group->node, &group->port->table);
			is_mgid0 = !memcmp(&mgid0, &group->rec.mgid,
					   sizeof(mgid0));
			mcast_insert(group->port, group, is_mgid0);
		}
		spin_unlock_irq(&group->port->lock);
	}
	mcast_work_handler(&group->work);
}

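/*
 * Completion handler for an SA leave query.  Retry a failed leave a
 * limited number of times before continuing group processing.
 */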
static void leave_handler(int status, struct ib_sa_mcmember_rec *rec,
			  void *context)
{
	struct mcast_group *group = context;

	if (status && group->retries > 0 &&
	    !send_leave(group, group->leave_state))
		group->retries--;
	else
		mcast_work_handler(&group->work);
}

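/*
 * Find or create the group for the given MGID and take a reference on it.
 * A join with the zero MGID always allocates a new group, since the SA
 * assigns the real MGID when the join completes.
 */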
static struct mcast_group *acquire_group(struct mcast_port *port,
					 union ib_gid *mgid, gfp_t gfp_mask)
{
	struct mcast_group *group, *cur_group;
	unsigned long flags;
	int is_mgid0;

	is_mgid0 = !memcmp(&mgid0, mgid, sizeof mgid0);
	if (!is_mgid0) {
		spin_lock_irqsave(&port->lock, flags);
		group = mcast_find(port, mgid);
		if (group)
			goto found;
		spin_unlock_irqrestore(&port->lock, flags);
	}

	group = kzalloc(sizeof *group, gfp_mask);
	if (!group)
		return NULL;

	group->retries = 3;
	group->port = port;
	group->rec.mgid = *mgid;
	group->pkey_index = MCAST_INVALID_PKEY_INDEX;
	INIT_LIST_HEAD(&group->pending_list);
	INIT_LIST_HEAD(&group->active_list);
	INIT_WORK(&group->work, mcast_work_handler);
	spin_lock_init(&group->lock);

	spin_lock_irqsave(&port->lock, flags);
	cur_group = mcast_insert(port, group, is_mgid0);
	if (cur_group) {
		kfree(group);
		group = cur_group;
	} else
		atomic_inc(&port->refcount);
found:
	atomic_inc(&group->refcount);
	spin_unlock_irqrestore(&port->lock, flags);
	return group;
}

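/*
 * Typical consumer usage (a sketch; the consumer-side names below are
 * illustrative):
 *
 *	mc = ib_sa_join_multicast(&my_sa_client, device, port_num, &rec,
 *				  comp_mask, GFP_KERNEL, my_join_done, ctx);
 *	...
 *	ib_sa_free_multicast(mc);
 *
 * my_join_done() is called with the join status; a nonzero return value
 * from the callback tells the core to free the multicast entry itself.
 */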
/*
 * We serialize all join requests to a single group to make our lives much
 * easier.  Otherwise, two users could try to join the same group
 * simultaneously, with different configurations, one could leave while the
 * join is in progress, etc., which makes locking around error recovery
 * difficult.
 */
struct ib_sa_multicast *
ib_sa_join_multicast(struct ib_sa_client *client,
		     struct ib_device *device, u8 port_num,
		     struct ib_sa_mcmember_rec *rec,
		     ib_sa_comp_mask comp_mask, gfp_t gfp_mask,
		     int (*callback)(int status,
				     struct ib_sa_multicast *multicast),
		     void *context)
{
	struct mcast_device *dev;
	struct mcast_member *member;
	struct ib_sa_multicast *multicast;
	int ret;

	dev = ib_get_client_data(device, &mcast_client);
	if (!dev)
		return ERR_PTR(-ENODEV);

	member = kmalloc(sizeof *member, gfp_mask);
	if (!member)
		return ERR_PTR(-ENOMEM);

	ib_sa_client_get(client);
	member->client = client;
	member->multicast.rec = *rec;
	member->multicast.comp_mask = comp_mask;
	member->multicast.callback = callback;
	member->multicast.context = context;
	init_completion(&member->comp);
	atomic_set(&member->refcount, 1);
	member->state = MCAST_JOINING;

	member->group = acquire_group(&dev->port[port_num - dev->start_port],
				      &rec->mgid, gfp_mask);
	if (!member->group) {
		ret = -ENOMEM;
		goto err;
	}

	/*
	 * The user will get the multicast structure in their callback.  They
	 * could then free the multicast structure before we can return from
	 * this routine.  So we save the pointer to return before queuing
	 * any callback.
	 */
	multicast = &member->multicast;
	queue_join(member);
	return multicast;

err:
	ib_sa_client_put(client);
	kfree(member);
	return ERR_PTR(ret);
}
EXPORT_SYMBOL(ib_sa_join_multicast);

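/*
 * Remove a member from its group.  If this leaves a join state with no
 * remaining members, the group's work handler sends the corresponding
 * leave to the SA.  Blocks until any outstanding callbacks for the member
 * have completed.
 */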
void ib_sa_free_multicast(struct ib_sa_multicast *multicast)
{
	struct mcast_member *member;
	struct mcast_group *group;

	member = container_of(multicast, struct mcast_member, multicast);
	group = member->group;

	spin_lock_irq(&group->lock);
	if (member->state == MCAST_MEMBER)
		adjust_membership(group, multicast->rec.join_state, -1);

	list_del_init(&member->list);

	if (group->state == MCAST_IDLE) {
		group->state = MCAST_BUSY;
		spin_unlock_irq(&group->lock);
		/* Continue to hold reference on group until callback */
		queue_work(mcast_wq, &group->work);
	} else {
		spin_unlock_irq(&group->lock);
		release_group(group);
	}

	deref_member(member);
	wait_for_completion(&member->comp);
	ib_sa_client_put(member->client);
	kfree(member);
}
EXPORT_SYMBOL(ib_sa_free_multicast);

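/*
 * Return a cached copy of the member record for a group known on this
 * port, or -EADDRNOTAVAIL if no such group is found.
 */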
int ib_sa_get_mcmember_rec(struct ib_device *device, u8 port_num,
			   union ib_gid *mgid, struct ib_sa_mcmember_rec *rec)
{
	struct mcast_device *dev;
	struct mcast_port *port;
	struct mcast_group *group;
	unsigned long flags;
	int ret = 0;

	dev = ib_get_client_data(device, &mcast_client);
	if (!dev)
		return -ENODEV;

	port = &dev->port[port_num - dev->start_port];
	spin_lock_irqsave(&port->lock, flags);
	group = mcast_find(port, mgid);
	if (group)
		*rec = group->rec;
	else
		ret = -EADDRNOTAVAIL;
	spin_unlock_irqrestore(&port->lock, flags);

	return ret;
}
EXPORT_SYMBOL(ib_sa_get_mcmember_rec);

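/*
 * Initialize address handle attributes from a multicast member record,
 * resolving the source GID index for RoCE or IB ports.
 */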
int ib_init_ah_from_mcmember(struct ib_device *device, u8 port_num,
			     struct ib_sa_mcmember_rec *rec,
			     struct net_device *ndev,
			     enum ib_gid_type gid_type,
			     struct ib_ah_attr *ah_attr)
{
	int ret;
	u16 gid_index;
	u8 p;

	if (rdma_protocol_roce(device, port_num)) {
		ret = ib_find_cached_gid_by_port(device, &rec->port_gid,
						 gid_type, port_num,
						 ndev,
						 &gid_index);
	} else if (rdma_protocol_ib(device, port_num)) {
		ret = ib_find_cached_gid(device, &rec->port_gid,
					 IB_GID_TYPE_IB, NULL, &p,
					 &gid_index);
	} else {
		ret = -EINVAL;
	}

	if (ret)
		return ret;

	memset(ah_attr, 0, sizeof *ah_attr);
	ah_attr->dlid = be16_to_cpu(rec->mlid);
	ah_attr->sl = rec->sl;
	ah_attr->port_num = port_num;
	ah_attr->static_rate = rec->rate;

	ah_attr->ah_flags = IB_AH_GRH;
	ah_attr->grh.dgid = rec->mgid;

	ah_attr->grh.sgid_index = (u8) gid_index;
	ah_attr->grh.flow_label = be32_to_cpu(rec->flow_label);
	ah_attr->grh.hop_limit = rec->hop_limit;
	ah_attr->grh.traffic_class = rec->traffic_class;

	return 0;
}
EXPORT_SYMBOL(ib_init_ah_from_mcmember);

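/*
 * Flag every group on a port with the given event state and schedule its
 * work handler so that its members are notified of the event.
 */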
static void mcast_groups_event(struct mcast_port *port,
			       enum mcast_group_state state)
{
	struct mcast_group *group;
	struct rb_node *node;
	unsigned long flags;

	spin_lock_irqsave(&port->lock, flags);
	for (node = rb_first(&port->table); node; node = rb_next(node)) {
		group = rb_entry(node, struct mcast_group, node);
		spin_lock(&group->lock);
		if (group->state == MCAST_IDLE) {
			atomic_inc(&group->refcount);
			queue_work(mcast_wq, &group->work);
		}
		if (group->state != MCAST_GROUP_ERROR)
			group->state = state;
		spin_unlock(&group->lock);
	}
	spin_unlock_irqrestore(&port->lock, flags);
}

static void mcast_event_handler(struct ib_event_handler *handler,
				struct ib_event *event)
{
	struct mcast_device *dev;
	int index;

	dev = container_of(handler, struct mcast_device, event_handler);
	if (!rdma_cap_ib_mcast(dev->device, event->element.port_num))
		return;

	index = event->element.port_num - dev->start_port;

	switch (event->event) {
	case IB_EVENT_PORT_ERR:
	case IB_EVENT_LID_CHANGE:
	case IB_EVENT_SM_CHANGE:
	case IB_EVENT_CLIENT_REREGISTER:
		mcast_groups_event(&dev->port[index], MCAST_GROUP_ERROR);
		break;
	case IB_EVENT_PKEY_CHANGE:
		mcast_groups_event(&dev->port[index], MCAST_PKEY_EVENT);
		break;
	default:
		break;
	}
}

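/*
 * Client add callback: allocate per-port state for each port that supports
 * IB multicast and register for device events.
 */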
static void mcast_add_one(struct ib_device *device)
{
	struct mcast_device *dev;
	struct mcast_port *port;
	int i;
	int count = 0;

	dev = kmalloc(sizeof *dev + device->phys_port_cnt * sizeof *port,
		      GFP_KERNEL);
	if (!dev)
		return;

	dev->start_port = rdma_start_port(device);
	dev->end_port = rdma_end_port(device);

	for (i = 0; i <= dev->end_port - dev->start_port; i++) {
		if (!rdma_cap_ib_mcast(device, dev->start_port + i))
			continue;
		port = &dev->port[i];
		port->dev = dev;
		port->port_num = dev->start_port + i;
		spin_lock_init(&port->lock);
		port->table = RB_ROOT;
		init_completion(&port->comp);
		atomic_set(&port->refcount, 1);
		++count;
	}

	if (!count) {
		kfree(dev);
		return;
	}

	dev->device = device;
	ib_set_client_data(device, &mcast_client, dev);

	INIT_IB_EVENT_HANDLER(&dev->event_handler, device, mcast_event_handler);
	ib_register_event_handler(&dev->event_handler);
}

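/*
 * Client remove callback: flush outstanding work and wait for all group
 * references on each port to be released before freeing device state.
 */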
static void mcast_remove_one(struct ib_device *device, void *client_data)
{
	struct mcast_device *dev = client_data;
	struct mcast_port *port;
	int i;

	if (!dev)
		return;

	ib_unregister_event_handler(&dev->event_handler);
	flush_workqueue(mcast_wq);

	for (i = 0; i <= dev->end_port - dev->start_port; i++) {
		if (rdma_cap_ib_mcast(device, dev->start_port + i)) {
			port = &dev->port[i];
			deref_port(port);
			wait_for_completion(&port->comp);
		}
	}

	kfree(dev);
}

int mcast_init(void)
{
	int ret;

	mcast_wq = alloc_ordered_workqueue("ib_mcast", WQ_MEM_RECLAIM);
	if (!mcast_wq)
		return -ENOMEM;

	ib_sa_register_client(&sa_client);

	ret = ib_register_client(&mcast_client);
	if (ret)
		goto err;
	return 0;

err:
	ib_sa_unregister_client(&sa_client);
	destroy_workqueue(mcast_wq);
	return ret;
}

void mcast_cleanup(void)
{
	ib_unregister_client(&mcast_client);
	ib_sa_unregister_client(&sa_client);
	destroy_workqueue(mcast_wq);
}
