cachepc-linux

Fork of AMDESE/linux with modifications for the CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux

multicast.c (24055B)


/*
 * Copyright (c) 2006 Intel Corporation.  All rights reserved.
 *
 * This software is available to you under a choice of one of two
 * licenses.  You may choose to be licensed under the terms of the GNU
 * General Public License (GPL) Version 2, available from the file
 * COPYING in the main directory of this source tree, or the
 * OpenIB.org BSD license below:
 *
 *     Redistribution and use in source and binary forms, with or
 *     without modification, are permitted provided that the following
 *     conditions are met:
 *
 *      - Redistributions of source code must retain the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer.
 *
 *      - Redistributions in binary form must reproduce the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer in the documentation and/or other materials
 *        provided with the distribution.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include <linux/completion.h>
#include <linux/dma-mapping.h>
#include <linux/err.h>
#include <linux/interrupt.h>
#include <linux/export.h>
#include <linux/slab.h>
#include <linux/bitops.h>
#include <linux/random.h>

#include <rdma/ib_cache.h>
#include "sa.h"

static int mcast_add_one(struct ib_device *device);
static void mcast_remove_one(struct ib_device *device, void *client_data);

static struct ib_client mcast_client = {
	.name   = "ib_multicast",
	.add    = mcast_add_one,
	.remove = mcast_remove_one
};

static struct ib_sa_client	sa_client;
static struct workqueue_struct	*mcast_wq;
static union ib_gid mgid0;
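/*
 * mgid0 is never written and so stays all-zero; a join request whose MGID
 * matches it asks the SA to assign the MGID (see acquire_group() and
 * join_handler() below).
 */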

struct mcast_device;

struct mcast_port {
	struct mcast_device	*dev;
	spinlock_t		lock;
	struct rb_root		table;
	refcount_t		refcount;
	struct completion	comp;
	u32			port_num;
};

struct mcast_device {
	struct ib_device	*device;
	struct ib_event_handler	event_handler;
	int			start_port;
	int			end_port;
	struct mcast_port	port[];
};

enum mcast_state {
	MCAST_JOINING,
	MCAST_MEMBER,
	MCAST_ERROR,
};

enum mcast_group_state {
	MCAST_IDLE,
	MCAST_BUSY,
	MCAST_GROUP_ERROR,
	MCAST_PKEY_EVENT
};

enum {
	MCAST_INVALID_PKEY_INDEX = 0xFFFF
};

struct mcast_member;

struct mcast_group {
	struct ib_sa_mcmember_rec rec;
	struct rb_node		node;
	struct mcast_port	*port;
	spinlock_t		lock;
	struct work_struct	work;
	struct list_head	pending_list;
	struct list_head	active_list;
	struct mcast_member	*last_join;
	int			members[NUM_JOIN_MEMBERSHIP_TYPES];
	atomic_t		refcount;
	enum mcast_group_state	state;
	struct ib_sa_query	*query;
	u16			pkey_index;
	u8			leave_state;
	int			retries;
};

struct mcast_member {
	struct ib_sa_multicast	multicast;
	struct ib_sa_client	*client;
	struct mcast_group	*group;
	struct list_head	list;
	enum mcast_state	state;
	refcount_t		refcount;
	struct completion	comp;
};

static void join_handler(int status, struct ib_sa_mcmember_rec *rec,
			 void *context);
static void leave_handler(int status, struct ib_sa_mcmember_rec *rec,
			  void *context);

static struct mcast_group *mcast_find(struct mcast_port *port,
				      union ib_gid *mgid)
{
	struct rb_node *node = port->table.rb_node;
	struct mcast_group *group;
	int ret;

	while (node) {
		group = rb_entry(node, struct mcast_group, node);
		ret = memcmp(mgid->raw, group->rec.mgid.raw, sizeof *mgid);
		if (!ret)
			return group;

		if (ret < 0)
			node = node->rb_left;
		else
			node = node->rb_right;
	}
	return NULL;
}

static struct mcast_group *mcast_insert(struct mcast_port *port,
					struct mcast_group *group,
					int allow_duplicates)
{
	struct rb_node **link = &port->table.rb_node;
	struct rb_node *parent = NULL;
	struct mcast_group *cur_group;
	int ret;

	while (*link) {
		parent = *link;
		cur_group = rb_entry(parent, struct mcast_group, node);

		ret = memcmp(group->rec.mgid.raw, cur_group->rec.mgid.raw,
			     sizeof group->rec.mgid);
		if (ret < 0)
			link = &(*link)->rb_left;
		else if (ret > 0)
			link = &(*link)->rb_right;
		else if (allow_duplicates)
			link = &(*link)->rb_left;
		else
			return cur_group;
	}
	rb_link_node(&group->node, parent, link);
	rb_insert_color(&group->node, &port->table);
	return NULL;
}
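/*
 * The per-port rb-tree is keyed by MGID.  allow_duplicates exists because
 * several groups can briefly share the all-zero MGID while they wait for
 * the SA to assign each of them a real MGID.
 */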

static void deref_port(struct mcast_port *port)
{
	if (refcount_dec_and_test(&port->refcount))
		complete(&port->comp);
}

static void release_group(struct mcast_group *group)
{
	struct mcast_port *port = group->port;
	unsigned long flags;

	spin_lock_irqsave(&port->lock, flags);
	if (atomic_dec_and_test(&group->refcount)) {
		rb_erase(&group->node, &port->table);
		spin_unlock_irqrestore(&port->lock, flags);
		kfree(group);
		deref_port(port);
	} else
		spin_unlock_irqrestore(&port->lock, flags);
}

static void deref_member(struct mcast_member *member)
{
	if (refcount_dec_and_test(&member->refcount))
		complete(&member->comp);
}

static void queue_join(struct mcast_member *member)
{
	struct mcast_group *group = member->group;
	unsigned long flags;

	spin_lock_irqsave(&group->lock, flags);
	list_add_tail(&member->list, &group->pending_list);
	if (group->state == MCAST_IDLE) {
		group->state = MCAST_BUSY;
		atomic_inc(&group->refcount);
		queue_work(mcast_wq, &group->work);
	}
	spin_unlock_irqrestore(&group->lock, flags);
}

/*
 * A multicast group has four types of members: full member, non member,
 * sendonly non member and sendonly full member.
 * We need to keep track of the number of members of each type based on
 * their join state.  Adjust the number of members that belong to the
 * specified join states.
 */
static void adjust_membership(struct mcast_group *group, u8 join_state, int inc)
{
	int i;

	for (i = 0; i < NUM_JOIN_MEMBERSHIP_TYPES; i++, join_state >>= 1)
		if (join_state & 0x1)
			group->members[i] += inc;
}
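/*
 * join_state is a bitmask with one bit per membership type; per the IBTA
 * MCMemberRecord JoinState field this is 0x1 full member, 0x2 non member,
 * 0x4 send-only non member and 0x8 send-only full member, so members[i]
 * counts the members holding bit i.
 */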

/*
 * If a multicast group has zero members left for a particular join state, but
 * the group is still a member with the SA, we need to leave that join state.
 * Determine which join states we still belong to, but that do not have any
 * active members.
 */
static u8 get_leave_state(struct mcast_group *group)
{
	u8 leave_state = 0;
	int i;

	for (i = 0; i < NUM_JOIN_MEMBERSHIP_TYPES; i++)
		if (!group->members[i])
			leave_state |= (0x1 << i);

	return leave_state & group->rec.join_state;
}

static int check_selector(ib_sa_comp_mask comp_mask,
			  ib_sa_comp_mask selector_mask,
			  ib_sa_comp_mask value_mask,
			  u8 selector, u8 src_value, u8 dst_value)
{
	int err;

	if (!(comp_mask & selector_mask) || !(comp_mask & value_mask))
		return 0;

	switch (selector) {
	case IB_SA_GT:
		err = (src_value <= dst_value);
		break;
	case IB_SA_LT:
		err = (src_value >= dst_value);
		break;
	case IB_SA_EQ:
		err = (src_value != dst_value);
		break;
	default:
		err = 0;
		break;
	}

	return err;
}
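/*
 * check_selector() returns nonzero when the group's current value
 * (src_value) cannot satisfy the requested selector: IB_SA_GT requires
 * src > dst, IB_SA_LT requires src < dst and IB_SA_EQ requires equality;
 * any other selector always matches.
 */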

static int cmp_rec(struct ib_sa_mcmember_rec *src,
		   struct ib_sa_mcmember_rec *dst, ib_sa_comp_mask comp_mask)
{
	/* MGID must already match */

	if (comp_mask & IB_SA_MCMEMBER_REC_PORT_GID &&
	    memcmp(&src->port_gid, &dst->port_gid, sizeof src->port_gid))
		return -EINVAL;
	if (comp_mask & IB_SA_MCMEMBER_REC_QKEY && src->qkey != dst->qkey)
		return -EINVAL;
	if (comp_mask & IB_SA_MCMEMBER_REC_MLID && src->mlid != dst->mlid)
		return -EINVAL;
	if (check_selector(comp_mask, IB_SA_MCMEMBER_REC_MTU_SELECTOR,
			   IB_SA_MCMEMBER_REC_MTU, dst->mtu_selector,
			   src->mtu, dst->mtu))
		return -EINVAL;
	if (comp_mask & IB_SA_MCMEMBER_REC_TRAFFIC_CLASS &&
	    src->traffic_class != dst->traffic_class)
		return -EINVAL;
	if (comp_mask & IB_SA_MCMEMBER_REC_PKEY && src->pkey != dst->pkey)
		return -EINVAL;
	if (check_selector(comp_mask, IB_SA_MCMEMBER_REC_RATE_SELECTOR,
			   IB_SA_MCMEMBER_REC_RATE, dst->rate_selector,
			   src->rate, dst->rate))
		return -EINVAL;
	if (check_selector(comp_mask,
			   IB_SA_MCMEMBER_REC_PACKET_LIFE_TIME_SELECTOR,
			   IB_SA_MCMEMBER_REC_PACKET_LIFE_TIME,
			   dst->packet_life_time_selector,
			   src->packet_life_time, dst->packet_life_time))
		return -EINVAL;
	if (comp_mask & IB_SA_MCMEMBER_REC_SL && src->sl != dst->sl)
		return -EINVAL;
	if (comp_mask & IB_SA_MCMEMBER_REC_FLOW_LABEL &&
	    src->flow_label != dst->flow_label)
		return -EINVAL;
	if (comp_mask & IB_SA_MCMEMBER_REC_HOP_LIMIT &&
	    src->hop_limit != dst->hop_limit)
		return -EINVAL;
	if (comp_mask & IB_SA_MCMEMBER_REC_SCOPE && src->scope != dst->scope)
		return -EINVAL;

	/* join_state checked separately, proxy_join ignored */

	return 0;
}

static int send_join(struct mcast_group *group, struct mcast_member *member)
{
	struct mcast_port *port = group->port;
	int ret;

	group->last_join = member;
	ret = ib_sa_mcmember_rec_query(&sa_client, port->dev->device,
				       port->port_num, IB_MGMT_METHOD_SET,
				       &member->multicast.rec,
				       member->multicast.comp_mask,
				       3000, GFP_KERNEL, join_handler, group,
				       &group->query);
	return (ret > 0) ? 0 : ret;
}

static int send_leave(struct mcast_group *group, u8 leave_state)
{
	struct mcast_port *port = group->port;
	struct ib_sa_mcmember_rec rec;
	int ret;

	rec = group->rec;
	rec.join_state = leave_state;
	group->leave_state = leave_state;

	ret = ib_sa_mcmember_rec_query(&sa_client, port->dev->device,
				       port->port_num, IB_SA_METHOD_DELETE, &rec,
				       IB_SA_MCMEMBER_REC_MGID     |
				       IB_SA_MCMEMBER_REC_PORT_GID |
				       IB_SA_MCMEMBER_REC_JOIN_STATE,
				       3000, GFP_KERNEL, leave_handler,
				       group, &group->query);
	return (ret > 0) ? 0 : ret;
}
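/*
 * ib_sa_mcmember_rec_query() returns a positive query ID on success, which
 * send_join() and send_leave() map to 0; negative errors pass through
 * unchanged.
 */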

static void join_group(struct mcast_group *group, struct mcast_member *member,
		       u8 join_state)
{
	member->state = MCAST_MEMBER;
	adjust_membership(group, join_state, 1);
	group->rec.join_state |= join_state;
	member->multicast.rec = group->rec;
	member->multicast.rec.join_state = join_state;
	list_move(&member->list, &group->active_list);
}

static int fail_join(struct mcast_group *group, struct mcast_member *member,
		     int status)
{
	spin_lock_irq(&group->lock);
	list_del_init(&member->list);
	spin_unlock_irq(&group->lock);
	return member->multicast.callback(status, &member->multicast);
}

static void process_group_error(struct mcast_group *group)
{
	struct mcast_member *member;
	int ret = 0;
	u16 pkey_index;

	if (group->state == MCAST_PKEY_EVENT)
		ret = ib_find_pkey(group->port->dev->device,
				   group->port->port_num,
				   be16_to_cpu(group->rec.pkey), &pkey_index);

	spin_lock_irq(&group->lock);
	if (group->state == MCAST_PKEY_EVENT && !ret &&
	    group->pkey_index == pkey_index)
		goto out;

	while (!list_empty(&group->active_list)) {
		member = list_entry(group->active_list.next,
				    struct mcast_member, list);
		refcount_inc(&member->refcount);
		list_del_init(&member->list);
		adjust_membership(group, member->multicast.rec.join_state, -1);
		member->state = MCAST_ERROR;
		spin_unlock_irq(&group->lock);

		ret = member->multicast.callback(-ENETRESET,
						 &member->multicast);
		deref_member(member);
		if (ret)
			ib_sa_free_multicast(&member->multicast);
		spin_lock_irq(&group->lock);
	}

	group->rec.join_state = 0;
out:
	group->state = MCAST_BUSY;
	spin_unlock_irq(&group->lock);
}

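/*
 * All state changes for a group run from this one work item, serializing
 * joins and leaves: pending joins are processed first, then any join
 * states left with no members are released at the SA.
 */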
static void mcast_work_handler(struct work_struct *work)
{
	struct mcast_group *group;
	struct mcast_member *member;
	struct ib_sa_multicast *multicast;
	int status, ret;
	u8 join_state;

	group = container_of(work, typeof(*group), work);
retest:
	spin_lock_irq(&group->lock);
	while (!list_empty(&group->pending_list) ||
	       (group->state != MCAST_BUSY)) {

		if (group->state != MCAST_BUSY) {
			spin_unlock_irq(&group->lock);
			process_group_error(group);
			goto retest;
		}

		member = list_entry(group->pending_list.next,
				    struct mcast_member, list);
		multicast = &member->multicast;
		join_state = multicast->rec.join_state;
		refcount_inc(&member->refcount);

		if (join_state == (group->rec.join_state & join_state)) {
			status = cmp_rec(&group->rec, &multicast->rec,
					 multicast->comp_mask);
			if (!status)
				join_group(group, member, join_state);
			else
				list_del_init(&member->list);
			spin_unlock_irq(&group->lock);
			ret = multicast->callback(status, multicast);
		} else {
			spin_unlock_irq(&group->lock);
			status = send_join(group, member);
			if (!status) {
				deref_member(member);
				return;
			}
			ret = fail_join(group, member, status);
		}

		deref_member(member);
		if (ret)
			ib_sa_free_multicast(&member->multicast);
		spin_lock_irq(&group->lock);
	}

	join_state = get_leave_state(group);
	if (join_state) {
		group->rec.join_state &= ~join_state;
		spin_unlock_irq(&group->lock);
		if (send_leave(group, join_state))
			goto retest;
	} else {
		group->state = MCAST_IDLE;
		spin_unlock_irq(&group->lock);
		release_group(group);
	}
}

/*
 * Fail a join request if it is still active - at the head of the pending queue.
 */
static void process_join_error(struct mcast_group *group, int status)
{
	struct mcast_member *member;
	int ret;

	spin_lock_irq(&group->lock);
	member = list_entry(group->pending_list.next,
			    struct mcast_member, list);
	if (group->last_join == member) {
		refcount_inc(&member->refcount);
		list_del_init(&member->list);
		spin_unlock_irq(&group->lock);
		ret = member->multicast.callback(status, &member->multicast);
		deref_member(member);
		if (ret)
			ib_sa_free_multicast(&member->multicast);
	} else
		spin_unlock_irq(&group->lock);
}

static void join_handler(int status, struct ib_sa_mcmember_rec *rec,
			 void *context)
{
	struct mcast_group *group = context;
	u16 pkey_index = MCAST_INVALID_PKEY_INDEX;

	if (status)
		process_join_error(group, status);
	else {
		int mgids_changed, is_mgid0;

		if (ib_find_pkey(group->port->dev->device,
				 group->port->port_num, be16_to_cpu(rec->pkey),
				 &pkey_index))
			pkey_index = MCAST_INVALID_PKEY_INDEX;

		spin_lock_irq(&group->port->lock);
		if (group->state == MCAST_BUSY &&
		    group->pkey_index == MCAST_INVALID_PKEY_INDEX)
			group->pkey_index = pkey_index;
		mgids_changed = memcmp(&rec->mgid, &group->rec.mgid,
				       sizeof(group->rec.mgid));
		group->rec = *rec;
		if (mgids_changed) {
			rb_erase(&group->node, &group->port->table);
			is_mgid0 = !memcmp(&mgid0, &group->rec.mgid,
					   sizeof(mgid0));
			mcast_insert(group->port, group, is_mgid0);
		}
		spin_unlock_irq(&group->port->lock);
	}
	mcast_work_handler(&group->work);
}
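/*
 * If the join was issued with the all-zero MGID, the SA's reply carries
 * the MGID it assigned; join_handler() above then re-keys the group in the
 * port's rb-tree under that assigned MGID.
 */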

static void leave_handler(int status, struct ib_sa_mcmember_rec *rec,
			  void *context)
{
	struct mcast_group *group = context;

	if (status && group->retries > 0 &&
	    !send_leave(group, group->leave_state))
		group->retries--;
	else
		mcast_work_handler(&group->work);
}

static struct mcast_group *acquire_group(struct mcast_port *port,
					 union ib_gid *mgid, gfp_t gfp_mask)
{
	struct mcast_group *group, *cur_group;
	unsigned long flags;
	int is_mgid0;

	is_mgid0 = !memcmp(&mgid0, mgid, sizeof mgid0);
	if (!is_mgid0) {
		spin_lock_irqsave(&port->lock, flags);
		group = mcast_find(port, mgid);
		if (group)
			goto found;
		spin_unlock_irqrestore(&port->lock, flags);
	}

	group = kzalloc(sizeof *group, gfp_mask);
	if (!group)
		return NULL;

	group->retries = 3;
	group->port = port;
	group->rec.mgid = *mgid;
	group->pkey_index = MCAST_INVALID_PKEY_INDEX;
	INIT_LIST_HEAD(&group->pending_list);
	INIT_LIST_HEAD(&group->active_list);
	INIT_WORK(&group->work, mcast_work_handler);
	spin_lock_init(&group->lock);

	spin_lock_irqsave(&port->lock, flags);
	cur_group = mcast_insert(port, group, is_mgid0);
	if (cur_group) {
		kfree(group);
		group = cur_group;
	} else
		refcount_inc(&port->refcount);
found:
	atomic_inc(&group->refcount);
	spin_unlock_irqrestore(&port->lock, flags);
	return group;
}
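/*
 * acquire_group() is a lookup-or-allocate helper: both the lookup and the
 * insert run under port->lock, and a racing allocation is resolved by
 * keeping whichever group reached the tree first and freeing the other.
 */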

/*
 * We serialize all join requests to a single group to make our lives much
 * easier.  Otherwise, two users could try to join the same group
 * simultaneously, with different configurations, one could leave while the
 * join is in progress, etc., which makes locking around error recovery
 * difficult.
 */
struct ib_sa_multicast *
ib_sa_join_multicast(struct ib_sa_client *client,
		     struct ib_device *device, u32 port_num,
		     struct ib_sa_mcmember_rec *rec,
		     ib_sa_comp_mask comp_mask, gfp_t gfp_mask,
		     int (*callback)(int status,
				     struct ib_sa_multicast *multicast),
		     void *context)
{
	struct mcast_device *dev;
	struct mcast_member *member;
	struct ib_sa_multicast *multicast;
	int ret;

	dev = ib_get_client_data(device, &mcast_client);
	if (!dev)
		return ERR_PTR(-ENODEV);

	member = kmalloc(sizeof *member, gfp_mask);
	if (!member)
		return ERR_PTR(-ENOMEM);

	ib_sa_client_get(client);
	member->client = client;
	member->multicast.rec = *rec;
	member->multicast.comp_mask = comp_mask;
	member->multicast.callback = callback;
	member->multicast.context = context;
	init_completion(&member->comp);
	refcount_set(&member->refcount, 1);
	member->state = MCAST_JOINING;

	member->group = acquire_group(&dev->port[port_num - dev->start_port],
				      &rec->mgid, gfp_mask);
	if (!member->group) {
		ret = -ENOMEM;
		goto err;
	}

	/*
	 * The user will get the multicast structure in their callback.  They
	 * could then free the multicast structure before we can return from
	 * this routine.  So we save the pointer to return before queuing
	 * any callback.
	 */
	multicast = &member->multicast;
	queue_join(member);
	return multicast;

err:
	ib_sa_client_put(client);
	kfree(member);
	return ERR_PTR(ret);
}
EXPORT_SYMBOL(ib_sa_join_multicast);
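/*
 * Hypothetical caller sketch (names are illustrative; see the ipoib driver
 * for a real user).  A nonzero return from the callback makes this module
 * free the multicast entry itself:
 *
 *	static int my_join_cb(int status, struct ib_sa_multicast *mc)
 *	{
 *		if (status)
 *			return status;	(nonzero return frees mc)
 *		... use mc->rec: the SA-assigned mlid, mgid, qkey ...
 *		return 0;
 *	}
 *
 *	mc = ib_sa_join_multicast(&my_client, device, port_num, &rec,
 *				  comp_mask, GFP_KERNEL, my_join_cb, ctx);
 *	if (IS_ERR(mc))
 *		return PTR_ERR(mc);
 */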

void ib_sa_free_multicast(struct ib_sa_multicast *multicast)
{
	struct mcast_member *member;
	struct mcast_group *group;

	member = container_of(multicast, struct mcast_member, multicast);
	group = member->group;

	spin_lock_irq(&group->lock);
	if (member->state == MCAST_MEMBER)
		adjust_membership(group, multicast->rec.join_state, -1);

	list_del_init(&member->list);

	if (group->state == MCAST_IDLE) {
		group->state = MCAST_BUSY;
		spin_unlock_irq(&group->lock);
		/* Continue to hold reference on group until callback */
		queue_work(mcast_wq, &group->work);
	} else {
		spin_unlock_irq(&group->lock);
		release_group(group);
	}

	deref_member(member);
	wait_for_completion(&member->comp);
	ib_sa_client_put(member->client);
	kfree(member);
}
EXPORT_SYMBOL(ib_sa_free_multicast);
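/*
 * Note the teardown order above: the member reference is dropped and the
 * completion awaited before kfree(), so no callback can still be touching
 * the member when it is freed.
 */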

int ib_sa_get_mcmember_rec(struct ib_device *device, u32 port_num,
			   union ib_gid *mgid, struct ib_sa_mcmember_rec *rec)
{
	struct mcast_device *dev;
	struct mcast_port *port;
	struct mcast_group *group;
	unsigned long flags;
	int ret = 0;

	dev = ib_get_client_data(device, &mcast_client);
	if (!dev)
		return -ENODEV;

	port = &dev->port[port_num - dev->start_port];
	spin_lock_irqsave(&port->lock, flags);
	group = mcast_find(port, mgid);
	if (group)
		*rec = group->rec;
	else
		ret = -EADDRNOTAVAIL;
	spin_unlock_irqrestore(&port->lock, flags);

	return ret;
}
EXPORT_SYMBOL(ib_sa_get_mcmember_rec);

/**
 * ib_init_ah_from_mcmember - Initialize AH attribute from multicast
 * member record and gid of the device.
 * @device:	RDMA device
 * @port_num:	Port of the rdma device to consider
 * @rec:	Multicast member record to use
 * @ndev:	Optional netdevice, applicable only for RoCE
 * @gid_type:	GID type to consider
 * @ah_attr:	AH attribute to fill in on successful completion
 *
 * ib_init_ah_from_mcmember() initializes an AH attribute based on the
 * multicast member record and other device properties. On success the caller
 * is responsible for calling rdma_destroy_ah_attr on the ah_attr. Returns 0
 * on success or an appropriate error code.
 */
int ib_init_ah_from_mcmember(struct ib_device *device, u32 port_num,
			     struct ib_sa_mcmember_rec *rec,
			     struct net_device *ndev,
			     enum ib_gid_type gid_type,
			     struct rdma_ah_attr *ah_attr)
{
	const struct ib_gid_attr *sgid_attr;

	/* GID table is not based on the netdevice for IB link layer,
	 * so ignore ndev during search.
	 */
	if (rdma_protocol_ib(device, port_num))
		ndev = NULL;
	else if (!rdma_protocol_roce(device, port_num))
		return -EINVAL;

	sgid_attr = rdma_find_gid_by_port(device, &rec->port_gid,
					  gid_type, port_num, ndev);
	if (IS_ERR(sgid_attr))
		return PTR_ERR(sgid_attr);

	memset(ah_attr, 0, sizeof(*ah_attr));
	ah_attr->type = rdma_ah_find_type(device, port_num);

	rdma_ah_set_dlid(ah_attr, be16_to_cpu(rec->mlid));
	rdma_ah_set_sl(ah_attr, rec->sl);
	rdma_ah_set_port_num(ah_attr, port_num);
	rdma_ah_set_static_rate(ah_attr, rec->rate);
	rdma_move_grh_sgid_attr(ah_attr, &rec->mgid,
				be32_to_cpu(rec->flow_label),
				rec->hop_limit,	rec->traffic_class,
				sgid_attr);
	return 0;
}
EXPORT_SYMBOL(ib_init_ah_from_mcmember);

static void mcast_groups_event(struct mcast_port *port,
			       enum mcast_group_state state)
{
	struct mcast_group *group;
	struct rb_node *node;
	unsigned long flags;

	spin_lock_irqsave(&port->lock, flags);
	for (node = rb_first(&port->table); node; node = rb_next(node)) {
		group = rb_entry(node, struct mcast_group, node);
		spin_lock(&group->lock);
		if (group->state == MCAST_IDLE) {
			atomic_inc(&group->refcount);
			queue_work(mcast_wq, &group->work);
		}
		if (group->state != MCAST_GROUP_ERROR)
			group->state = state;
		spin_unlock(&group->lock);
	}
	spin_unlock_irqrestore(&port->lock, flags);
}
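/*
 * On a port event every idle group gets its work item queued so that
 * process_group_error() can notify members; a group already marked
 * MCAST_GROUP_ERROR keeps that state rather than being downgraded to a
 * pkey event.
 */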

static void mcast_event_handler(struct ib_event_handler *handler,
				struct ib_event *event)
{
	struct mcast_device *dev;
	int index;

	dev = container_of(handler, struct mcast_device, event_handler);
	if (!rdma_cap_ib_mcast(dev->device, event->element.port_num))
		return;

	index = event->element.port_num - dev->start_port;

	switch (event->event) {
	case IB_EVENT_PORT_ERR:
	case IB_EVENT_LID_CHANGE:
	case IB_EVENT_CLIENT_REREGISTER:
		mcast_groups_event(&dev->port[index], MCAST_GROUP_ERROR);
		break;
	case IB_EVENT_PKEY_CHANGE:
		mcast_groups_event(&dev->port[index], MCAST_PKEY_EVENT);
		break;
	default:
		break;
	}
}

static int mcast_add_one(struct ib_device *device)
{
	struct mcast_device *dev;
	struct mcast_port *port;
	int i;
	int count = 0;

	dev = kmalloc(struct_size(dev, port, device->phys_port_cnt),
		      GFP_KERNEL);
	if (!dev)
		return -ENOMEM;

	dev->start_port = rdma_start_port(device);
	dev->end_port = rdma_end_port(device);

	for (i = 0; i <= dev->end_port - dev->start_port; i++) {
		if (!rdma_cap_ib_mcast(device, dev->start_port + i))
			continue;
		port = &dev->port[i];
		port->dev = dev;
		port->port_num = dev->start_port + i;
		spin_lock_init(&port->lock);
		port->table = RB_ROOT;
		init_completion(&port->comp);
		refcount_set(&port->refcount, 1);
		++count;
	}

	if (!count) {
		kfree(dev);
		return -EOPNOTSUPP;
	}

	dev->device = device;
	ib_set_client_data(device, &mcast_client, dev);

	INIT_IB_EVENT_HANDLER(&dev->event_handler, device, mcast_event_handler);
	ib_register_event_handler(&dev->event_handler);
	return 0;
}

static void mcast_remove_one(struct ib_device *device, void *client_data)
{
	struct mcast_device *dev = client_data;
	struct mcast_port *port;
	int i;

	ib_unregister_event_handler(&dev->event_handler);
	flush_workqueue(mcast_wq);

	for (i = 0; i <= dev->end_port - dev->start_port; i++) {
		if (rdma_cap_ib_mcast(device, dev->start_port + i)) {
			port = &dev->port[i];
			deref_port(port);
			wait_for_completion(&port->comp);
		}
	}

	kfree(dev);
}
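/*
 * mcast_remove_one() drops the initial port reference taken in
 * mcast_add_one(); wait_for_completion() then blocks until every group on
 * that port has been released.
 */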

int mcast_init(void)
{
	int ret;

	mcast_wq = alloc_ordered_workqueue("ib_mcast", WQ_MEM_RECLAIM);
	if (!mcast_wq)
		return -ENOMEM;

	ib_sa_register_client(&sa_client);

	ret = ib_register_client(&mcast_client);
	if (ret)
		goto err;
	return 0;

err:
	ib_sa_unregister_client(&sa_client);
	destroy_workqueue(mcast_wq);
	return ret;
}

void mcast_cleanup(void)
{
	ib_unregister_client(&mcast_client);
	ib_sa_unregister_client(&sa_client);
	destroy_workqueue(mcast_wq);
}