cachepc-linux

Fork of AMDESE/linux with modifications for CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux
Log | Files | Refs | README | LICENSE | sfeed.txt

pm_netlink.c (58602B)


      1// SPDX-License-Identifier: GPL-2.0
      2/* Multipath TCP
      3 *
      4 * Copyright (c) 2020, Red Hat, Inc.
      5 */
      6
      7#define pr_fmt(fmt) "MPTCP: " fmt
      8
      9#include <linux/inet.h>
     10#include <linux/kernel.h>
     11#include <net/tcp.h>
     12#include <net/netns/generic.h>
     13#include <net/mptcp.h>
     14#include <net/genetlink.h>
     15#include <uapi/linux/mptcp.h>
     16
     17#include "protocol.h"
     18#include "mib.h"
     19
/* forward declaration: the generic netlink family is defined further down
 * in this file (outside the part shown here)
 */
static struct genl_family mptcp_genl_family;

/* net_generic() slot holding this module's per-netns state, see
 * pm_nl_get_pernet()
 */
static int pm_nl_pernet_id;
     24
/* An address announced via ADD_ADDR on a given msk, linked into
 * msk->pm.anno_list and retransmitted by add_timer (mptcp_pm_add_timer()).
 */
struct mptcp_pm_add_entry {
	struct list_head	list;		/* node in msk->pm.anno_list */
	struct mptcp_addr_info	addr;		/* announced address */
	struct timer_list	add_timer;	/* ADD_ADDR retransmit timer */
	struct mptcp_sock	*sock;		/* msk owning this entry */
	u8			retrans_times;	/* retransmissions so far, see ADD_ADDR_RETRANS_MAX */
};
     32
/* Per network namespace state of the in-kernel path manager. */
struct pm_nl_pernet {
	/* protects pernet updates */
	spinlock_t		lock;
	/* configured endpoints (struct mptcp_pm_addr_entry); readers walk
	 * it under RCU, writers use the _rcu list helpers under @lock
	 */
	struct list_head	local_addr_list;
	/* number of entries in local_addr_list */
	unsigned int		addrs;
	/* not referenced in the part of this file shown here; purpose
	 * TODO confirm
	 */
	unsigned int		stale_loss_cnt;
	/* number of SIGNAL endpoints; bounds msk->pm.add_addr_signaled */
	unsigned int		add_addr_signal_max;
	/* bound on msk->pm.add_addr_accepted (ADD_ADDR accepted from peer) */
	unsigned int		add_addr_accept_max;
	/* number of SUBFLOW endpoints; bounds msk->pm.local_addr_used */
	unsigned int		local_addr_max;
	/* bound on msk->pm.subflows */
	unsigned int		subflows_max;
	/* starting point of the search when allocating an endpoint id */
	unsigned int		next_id;
	/* endpoint ids currently in use */
	DECLARE_BITMAP(id_bitmap, MPTCP_PM_MAX_ADDR_ID + 1);
};
     46
/* max endpoints per netns; also sizes the on-stack address arrays */
#define MPTCP_PM_ADDR_MAX	8
/* ADD_ADDR retransmissions performed before giving up */
#define ADD_ADDR_RETRANS_MAX	3
     49
     50static struct pm_nl_pernet *pm_nl_get_pernet(const struct net *net)
     51{
     52	return net_generic(net, pm_nl_pernet_id);
     53}
     54
     55static struct pm_nl_pernet *
     56pm_nl_get_pernet_from_msk(const struct mptcp_sock *msk)
     57{
     58	return pm_nl_get_pernet(sock_net((struct sock *)msk));
     59}
     60
     61bool mptcp_addresses_equal(const struct mptcp_addr_info *a,
     62			   const struct mptcp_addr_info *b, bool use_port)
     63{
     64	bool addr_equals = false;
     65
     66	if (a->family == b->family) {
     67		if (a->family == AF_INET)
     68			addr_equals = a->addr.s_addr == b->addr.s_addr;
     69#if IS_ENABLED(CONFIG_MPTCP_IPV6)
     70		else
     71			addr_equals = !ipv6_addr_cmp(&a->addr6, &b->addr6);
     72	} else if (a->family == AF_INET) {
     73		if (ipv6_addr_v4mapped(&b->addr6))
     74			addr_equals = a->addr.s_addr == b->addr6.s6_addr32[3];
     75	} else if (b->family == AF_INET) {
     76		if (ipv6_addr_v4mapped(&a->addr6))
     77			addr_equals = a->addr6.s6_addr32[3] == b->addr.s_addr;
     78#endif
     79	}
     80
     81	if (!addr_equals)
     82		return false;
     83	if (!use_port)
     84		return true;
     85
     86	return a->port == b->port;
     87}
     88
     89static void local_address(const struct sock_common *skc,
     90			  struct mptcp_addr_info *addr)
     91{
     92	addr->family = skc->skc_family;
     93	addr->port = htons(skc->skc_num);
     94	if (addr->family == AF_INET)
     95		addr->addr.s_addr = skc->skc_rcv_saddr;
     96#if IS_ENABLED(CONFIG_MPTCP_IPV6)
     97	else if (addr->family == AF_INET6)
     98		addr->addr6 = skc->skc_v6_rcv_saddr;
     99#endif
    100}
    101
    102static void remote_address(const struct sock_common *skc,
    103			   struct mptcp_addr_info *addr)
    104{
    105	addr->family = skc->skc_family;
    106	addr->port = skc->skc_dport;
    107	if (addr->family == AF_INET)
    108		addr->addr.s_addr = skc->skc_daddr;
    109#if IS_ENABLED(CONFIG_MPTCP_IPV6)
    110	else if (addr->family == AF_INET6)
    111		addr->addr6 = skc->skc_v6_daddr;
    112#endif
    113}
    114
    115static bool lookup_subflow_by_saddr(const struct list_head *list,
    116				    const struct mptcp_addr_info *saddr)
    117{
    118	struct mptcp_subflow_context *subflow;
    119	struct mptcp_addr_info cur;
    120	struct sock_common *skc;
    121
    122	list_for_each_entry(subflow, list, node) {
    123		skc = (struct sock_common *)mptcp_subflow_tcp_sock(subflow);
    124
    125		local_address(skc, &cur);
    126		if (mptcp_addresses_equal(&cur, saddr, saddr->port))
    127			return true;
    128	}
    129
    130	return false;
    131}
    132
    133static bool lookup_subflow_by_daddr(const struct list_head *list,
    134				    const struct mptcp_addr_info *daddr)
    135{
    136	struct mptcp_subflow_context *subflow;
    137	struct mptcp_addr_info cur;
    138	struct sock_common *skc;
    139
    140	list_for_each_entry(subflow, list, node) {
    141		skc = (struct sock_common *)mptcp_subflow_tcp_sock(subflow);
    142
    143		remote_address(skc, &cur);
    144		if (mptcp_addresses_equal(&cur, daddr, daddr->port))
    145			return true;
    146	}
    147
    148	return false;
    149}
    150
    151static struct mptcp_pm_addr_entry *
    152select_local_address(const struct pm_nl_pernet *pernet,
    153		     const struct mptcp_sock *msk)
    154{
    155	const struct sock *sk = (const struct sock *)msk;
    156	struct mptcp_pm_addr_entry *entry, *ret = NULL;
    157
    158	msk_owned_by_me(msk);
    159
    160	rcu_read_lock();
    161	list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) {
    162		if (!(entry->flags & MPTCP_PM_ADDR_FLAG_SUBFLOW))
    163			continue;
    164
    165		if (!test_bit(entry->addr.id, msk->pm.id_avail_bitmap))
    166			continue;
    167
    168		if (entry->addr.family != sk->sk_family) {
    169#if IS_ENABLED(CONFIG_MPTCP_IPV6)
    170			if ((entry->addr.family == AF_INET &&
    171			     !ipv6_addr_v4mapped(&sk->sk_v6_daddr)) ||
    172			    (sk->sk_family == AF_INET &&
    173			     !ipv6_addr_v4mapped(&entry->addr.addr6)))
    174#endif
    175				continue;
    176		}
    177
    178		ret = entry;
    179		break;
    180	}
    181	rcu_read_unlock();
    182	return ret;
    183}
    184
    185static struct mptcp_pm_addr_entry *
    186select_signal_address(struct pm_nl_pernet *pernet, const struct mptcp_sock *msk)
    187{
    188	struct mptcp_pm_addr_entry *entry, *ret = NULL;
    189
    190	rcu_read_lock();
    191	/* do not keep any additional per socket state, just signal
    192	 * the address list in order.
    193	 * Note: removal from the local address list during the msk life-cycle
    194	 * can lead to additional addresses not being announced.
    195	 */
    196	list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) {
    197		if (!test_bit(entry->addr.id, msk->pm.id_avail_bitmap))
    198			continue;
    199
    200		if (!(entry->flags & MPTCP_PM_ADDR_FLAG_SIGNAL))
    201			continue;
    202
    203		ret = entry;
    204		break;
    205	}
    206	rcu_read_unlock();
    207	return ret;
    208}
    209
    210unsigned int mptcp_pm_get_add_addr_signal_max(const struct mptcp_sock *msk)
    211{
    212	const struct pm_nl_pernet *pernet = pm_nl_get_pernet_from_msk(msk);
    213
    214	return READ_ONCE(pernet->add_addr_signal_max);
    215}
    216EXPORT_SYMBOL_GPL(mptcp_pm_get_add_addr_signal_max);
    217
    218unsigned int mptcp_pm_get_add_addr_accept_max(const struct mptcp_sock *msk)
    219{
    220	struct pm_nl_pernet *pernet = pm_nl_get_pernet_from_msk(msk);
    221
    222	return READ_ONCE(pernet->add_addr_accept_max);
    223}
    224EXPORT_SYMBOL_GPL(mptcp_pm_get_add_addr_accept_max);
    225
    226unsigned int mptcp_pm_get_subflows_max(const struct mptcp_sock *msk)
    227{
    228	struct pm_nl_pernet *pernet = pm_nl_get_pernet_from_msk(msk);
    229
    230	return READ_ONCE(pernet->subflows_max);
    231}
    232EXPORT_SYMBOL_GPL(mptcp_pm_get_subflows_max);
    233
    234unsigned int mptcp_pm_get_local_addr_max(const struct mptcp_sock *msk)
    235{
    236	struct pm_nl_pernet *pernet = pm_nl_get_pernet_from_msk(msk);
    237
    238	return READ_ONCE(pernet->local_addr_max);
    239}
    240EXPORT_SYMBOL_GPL(mptcp_pm_get_local_addr_max);
    241
    242bool mptcp_pm_nl_check_work_pending(struct mptcp_sock *msk)
    243{
    244	struct pm_nl_pernet *pernet = pm_nl_get_pernet_from_msk(msk);
    245
    246	if (msk->pm.subflows == mptcp_pm_get_subflows_max(msk) ||
    247	    (find_next_and_bit(pernet->id_bitmap, msk->pm.id_avail_bitmap,
    248			       MPTCP_PM_MAX_ADDR_ID + 1, 0) == MPTCP_PM_MAX_ADDR_ID + 1)) {
    249		WRITE_ONCE(msk->pm.work_pending, false);
    250		return false;
    251	}
    252	return true;
    253}
    254
    255struct mptcp_pm_add_entry *
    256mptcp_lookup_anno_list_by_saddr(const struct mptcp_sock *msk,
    257				const struct mptcp_addr_info *addr)
    258{
    259	struct mptcp_pm_add_entry *entry;
    260
    261	lockdep_assert_held(&msk->pm.lock);
    262
    263	list_for_each_entry(entry, &msk->pm.anno_list, list) {
    264		if (mptcp_addresses_equal(&entry->addr, addr, true))
    265			return entry;
    266	}
    267
    268	return NULL;
    269}
    270
    271bool mptcp_pm_sport_in_anno_list(struct mptcp_sock *msk, const struct sock *sk)
    272{
    273	struct mptcp_pm_add_entry *entry;
    274	struct mptcp_addr_info saddr;
    275	bool ret = false;
    276
    277	local_address((struct sock_common *)sk, &saddr);
    278
    279	spin_lock_bh(&msk->pm.lock);
    280	list_for_each_entry(entry, &msk->pm.anno_list, list) {
    281		if (mptcp_addresses_equal(&entry->addr, &saddr, true)) {
    282			ret = true;
    283			goto out;
    284		}
    285	}
    286
    287out:
    288	spin_unlock_bh(&msk->pm.lock);
    289	return ret;
    290}
    291
/* Retransmission timer callback of a struct mptcp_pm_add_entry.
 *
 * While an ADD_ADDR signal is still pending on the msk (as reported by
 * mptcp_pm_should_add_signal_addr()), only re-arm after TCP_RTO_MAX / 8.
 * Otherwise, under msk->pm.lock, re-announce entry->addr, request an ack
 * and bump retrans_times. Re-arm with the per-netns add_addr timeout while
 * retrans_times < ADD_ADDR_RETRANS_MAX. Once it equals
 * ADD_ADDR_RETRANS_MAX, call mptcp_pm_subflow_established().
 *
 * The trailing __sock_put() presumably balances the socket reference
 * taken by sk_reset_timer() when this timer was armed.
 * NOTE(review): the three early returns below skip __sock_put(); confirm
 * that reference is dropped elsewhere on those paths.
 */
static void mptcp_pm_add_timer(struct timer_list *timer)
{
	struct mptcp_pm_add_entry *entry = from_timer(entry, timer, add_timer);
	struct mptcp_sock *msk = entry->sock;
	struct sock *sk = (struct sock *)msk;

	pr_debug("msk=%p", msk);

	if (!msk)
		return;

	if (inet_sk_state_load(sk) == TCP_CLOSE)
		return;

	/* entries with address id 0 are never retransmitted here */
	if (!entry->addr.id)
		return;

	/* a signal is still pending: just look again later */
	if (mptcp_pm_should_add_signal_addr(msk)) {
		sk_reset_timer(sk, timer, jiffies + TCP_RTO_MAX / 8);
		goto out;
	}

	spin_lock_bh(&msk->pm.lock);

	/* re-check under the lock before retransmitting */
	if (!mptcp_pm_should_add_signal_addr(msk)) {
		pr_debug("retransmit ADD_ADDR id=%d", entry->addr.id);
		mptcp_pm_announce_addr(msk, &entry->addr, false);
		mptcp_pm_add_addr_send_ack(msk);
		entry->retrans_times++;
	}

	if (entry->retrans_times < ADD_ADDR_RETRANS_MAX)
		sk_reset_timer(sk, timer,
			       jiffies + mptcp_get_add_addr_timeout(sock_net(sk)));

	spin_unlock_bh(&msk->pm.lock);

	/* retransmissions exhausted: let the PM move on */
	if (entry->retrans_times == ADD_ADDR_RETRANS_MAX)
		mptcp_pm_subflow_established(msk);

out:
	__sock_put(sk);
}
    335
    336struct mptcp_pm_add_entry *
    337mptcp_pm_del_add_timer(struct mptcp_sock *msk,
    338		       const struct mptcp_addr_info *addr, bool check_id)
    339{
    340	struct mptcp_pm_add_entry *entry;
    341	struct sock *sk = (struct sock *)msk;
    342
    343	spin_lock_bh(&msk->pm.lock);
    344	entry = mptcp_lookup_anno_list_by_saddr(msk, addr);
    345	if (entry && (!check_id || entry->addr.id == addr->id))
    346		entry->retrans_times = ADD_ADDR_RETRANS_MAX;
    347	spin_unlock_bh(&msk->pm.lock);
    348
    349	if (entry && (!check_id || entry->addr.id == addr->id))
    350		sk_stop_timer_sync(sk, &entry->add_timer);
    351
    352	return entry;
    353}
    354
    355bool mptcp_pm_alloc_anno_list(struct mptcp_sock *msk,
    356			      const struct mptcp_pm_addr_entry *entry)
    357{
    358	struct mptcp_pm_add_entry *add_entry = NULL;
    359	struct sock *sk = (struct sock *)msk;
    360	struct net *net = sock_net(sk);
    361
    362	lockdep_assert_held(&msk->pm.lock);
    363
    364	add_entry = mptcp_lookup_anno_list_by_saddr(msk, &entry->addr);
    365
    366	if (add_entry) {
    367		if (mptcp_pm_is_kernel(msk))
    368			return false;
    369
    370		sk_reset_timer(sk, &add_entry->add_timer,
    371			       jiffies + mptcp_get_add_addr_timeout(net));
    372		return true;
    373	}
    374
    375	add_entry = kmalloc(sizeof(*add_entry), GFP_ATOMIC);
    376	if (!add_entry)
    377		return false;
    378
    379	list_add(&add_entry->list, &msk->pm.anno_list);
    380
    381	add_entry->addr = entry->addr;
    382	add_entry->sock = msk;
    383	add_entry->retrans_times = 0;
    384
    385	timer_setup(&add_entry->add_timer, mptcp_pm_add_timer, 0);
    386	sk_reset_timer(sk, &add_entry->add_timer,
    387		       jiffies + mptcp_get_add_addr_timeout(net));
    388
    389	return true;
    390}
    391
    392void mptcp_pm_free_anno_list(struct mptcp_sock *msk)
    393{
    394	struct mptcp_pm_add_entry *entry, *tmp;
    395	struct sock *sk = (struct sock *)msk;
    396	LIST_HEAD(free_list);
    397
    398	pr_debug("msk=%p", msk);
    399
    400	spin_lock_bh(&msk->pm.lock);
    401	list_splice_init(&msk->pm.anno_list, &free_list);
    402	spin_unlock_bh(&msk->pm.lock);
    403
    404	list_for_each_entry_safe(entry, tmp, &free_list, list) {
    405		sk_stop_timer_sync(sk, &entry->add_timer);
    406		kfree(entry);
    407	}
    408}
    409
    410static bool lookup_address_in_vec(const struct mptcp_addr_info *addrs, unsigned int nr,
    411				  const struct mptcp_addr_info *addr)
    412{
    413	int i;
    414
    415	for (i = 0; i < nr; i++) {
    416		if (mptcp_addresses_equal(&addrs[i], addr, addr->port))
    417			return true;
    418	}
    419
    420	return false;
    421}
    422
    423/* Fill all the remote addresses into the array addrs[],
    424 * and return the array size.
    425 */
    426static unsigned int fill_remote_addresses_vec(struct mptcp_sock *msk, bool fullmesh,
    427					      struct mptcp_addr_info *addrs)
    428{
    429	bool deny_id0 = READ_ONCE(msk->pm.remote_deny_join_id0);
    430	struct sock *sk = (struct sock *)msk, *ssk;
    431	struct mptcp_subflow_context *subflow;
    432	struct mptcp_addr_info remote = { 0 };
    433	unsigned int subflows_max;
    434	int i = 0;
    435
    436	subflows_max = mptcp_pm_get_subflows_max(msk);
    437	remote_address((struct sock_common *)sk, &remote);
    438
    439	/* Non-fullmesh endpoint, fill in the single entry
    440	 * corresponding to the primary MPC subflow remote address
    441	 */
    442	if (!fullmesh) {
    443		if (deny_id0)
    444			return 0;
    445
    446		msk->pm.subflows++;
    447		addrs[i++] = remote;
    448	} else {
    449		mptcp_for_each_subflow(msk, subflow) {
    450			ssk = mptcp_subflow_tcp_sock(subflow);
    451			remote_address((struct sock_common *)ssk, &addrs[i]);
    452			if (deny_id0 && mptcp_addresses_equal(&addrs[i], &remote, false))
    453				continue;
    454
    455			if (!lookup_address_in_vec(addrs, i, &addrs[i]) &&
    456			    msk->pm.subflows < subflows_max) {
    457				msk->pm.subflows++;
    458				i++;
    459			}
    460		}
    461	}
    462
    463	return i;
    464}
    465
    466static struct mptcp_pm_addr_entry *
    467__lookup_addr_by_id(struct pm_nl_pernet *pernet, unsigned int id)
    468{
    469	struct mptcp_pm_addr_entry *entry;
    470
    471	list_for_each_entry(entry, &pernet->local_addr_list, list) {
    472		if (entry->addr.id == id)
    473			return entry;
    474	}
    475	return NULL;
    476}
    477
    478static struct mptcp_pm_addr_entry *
    479__lookup_addr(struct pm_nl_pernet *pernet, const struct mptcp_addr_info *info,
    480	      bool lookup_by_id)
    481{
    482	struct mptcp_pm_addr_entry *entry;
    483
    484	list_for_each_entry(entry, &pernet->local_addr_list, list) {
    485		if ((!lookup_by_id && mptcp_addresses_equal(&entry->addr, info, true)) ||
    486		    (lookup_by_id && entry->addr.id == info->id))
    487			return entry;
    488	}
    489	return NULL;
    490}
    491
    492static int
    493lookup_id_by_addr(const struct pm_nl_pernet *pernet, const struct mptcp_addr_info *addr)
    494{
    495	const struct mptcp_pm_addr_entry *entry;
    496	int ret = -1;
    497
    498	rcu_read_lock();
    499	list_for_each_entry(entry, &pernet->local_addr_list, list) {
    500		if (mptcp_addresses_equal(&entry->addr, addr, entry->addr.port)) {
    501			ret = entry->addr.id;
    502			break;
    503		}
    504	}
    505	rcu_read_unlock();
    506	return ret;
    507}
    508
/* Core step of the in-kernel PM, called with msk->pm.lock held (the lock
 * is dropped around subflow creation):
 *  1) while add_addr_signaled < add_addr_signal_max, announce at most one
 *     still-available SIGNAL endpoint;
 *  2) while local_addr_used < local_addr_max and subflows < subflows_max,
 *     create subflows from still-available SUBFLOW endpoints;
 *  3) refresh msk->pm.work_pending via mptcp_pm_nl_check_work_pending().
 * If step 1 applies but an ADD_ADDR signal is still pending, return right
 * away, skipping steps 2 and 3.
 */
static void mptcp_pm_create_subflow_or_signal_addr(struct mptcp_sock *msk)
{
	struct sock *sk = (struct sock *)msk;
	struct mptcp_pm_addr_entry *local;
	unsigned int add_addr_signal_max;
	unsigned int local_addr_max;
	struct pm_nl_pernet *pernet;
	unsigned int subflows_max;

	pernet = pm_nl_get_pernet(sock_net(sk));

	add_addr_signal_max = mptcp_pm_get_add_addr_signal_max(msk);
	local_addr_max = mptcp_pm_get_local_addr_max(msk);
	subflows_max = mptcp_pm_get_subflows_max(msk);

	/* do lazy endpoint usage accounting for the MPC subflows: once per
	 * msk, mark the endpoint matching the first subflow's local address
	 * (if any) as no longer available
	 */
	if (unlikely(!(msk->pm.status & BIT(MPTCP_PM_MPC_ENDPOINT_ACCOUNTED))) && msk->first) {
		struct mptcp_addr_info mpc_addr;
		int mpc_id;

		local_address((struct sock_common *)msk->first, &mpc_addr);
		mpc_id = lookup_id_by_addr(pernet, &mpc_addr);
		if (mpc_id >= 0)
			__clear_bit(mpc_id, msk->pm.id_avail_bitmap);

		msk->pm.status |= BIT(MPTCP_PM_MPC_ENDPOINT_ACCOUNTED);
	}

	pr_debug("local %d:%d signal %d:%d subflows %d:%d\n",
		 msk->pm.local_addr_used, local_addr_max,
		 msk->pm.add_addr_signaled, add_addr_signal_max,
		 msk->pm.subflows, subflows_max);

	/* check first for announce */
	if (msk->pm.add_addr_signaled < add_addr_signal_max) {
		local = select_signal_address(pernet, msk);

		/* due to racing events on both ends we can reach here while
		 * previous add address is still running: if we invoke now
		 * mptcp_pm_announce_addr(), that will fail and the
		 * corresponding id will be marked as used.
		 * Instead let the PM machinery reschedule us when the
		 * current address announce will be completed.
		 */
		if (msk->pm.addr_signal & BIT(MPTCP_ADD_ADDR_SIGNAL))
			return;

		if (local) {
			/* the id is consumed only once the address is
			 * tracked for retransmission
			 */
			if (mptcp_pm_alloc_anno_list(msk, local)) {
				__clear_bit(local->addr.id, msk->pm.id_avail_bitmap);
				msk->pm.add_addr_signaled++;
				mptcp_pm_announce_addr(msk, &local->addr, false);
				mptcp_pm_nl_addr_send_ack(msk);
			}
		}
	}

	/* check if should create a new subflow */
	while (msk->pm.local_addr_used < local_addr_max &&
	       msk->pm.subflows < subflows_max) {
		struct mptcp_addr_info addrs[MPTCP_PM_ADDR_MAX];
		bool fullmesh;
		int i, nr;

		local = select_local_address(pernet, msk);
		if (!local)
			break;

		fullmesh = !!(local->flags & MPTCP_PM_ADDR_FLAG_FULLMESH);

		/* local_addr_used grows even when no remote address is
		 * found (nr == 0); the endpoint id then stays available
		 */
		msk->pm.local_addr_used++;
		nr = fill_remote_addresses_vec(msk, fullmesh, addrs);
		if (nr)
			__clear_bit(local->addr.id, msk->pm.id_avail_bitmap);
		/* subflow creation runs without pm.lock held */
		spin_unlock_bh(&msk->pm.lock);
		for (i = 0; i < nr; i++)
			__mptcp_subflow_connect(sk, &local->addr, &addrs[i]);
		spin_lock_bh(&msk->pm.lock);
	}
	mptcp_pm_nl_check_work_pending(msk);
}
    590
    591static void mptcp_pm_nl_fully_established(struct mptcp_sock *msk)
    592{
    593	mptcp_pm_create_subflow_or_signal_addr(msk);
    594}
    595
    596static void mptcp_pm_nl_subflow_established(struct mptcp_sock *msk)
    597{
    598	mptcp_pm_create_subflow_or_signal_addr(msk);
    599}
    600
    601/* Fill all the local addresses into the array addrs[],
    602 * and return the array size.
    603 */
    604static unsigned int fill_local_addresses_vec(struct mptcp_sock *msk,
    605					     struct mptcp_addr_info *addrs)
    606{
    607	struct sock *sk = (struct sock *)msk;
    608	struct mptcp_pm_addr_entry *entry;
    609	struct mptcp_addr_info local;
    610	struct pm_nl_pernet *pernet;
    611	unsigned int subflows_max;
    612	int i = 0;
    613
    614	pernet = pm_nl_get_pernet_from_msk(msk);
    615	subflows_max = mptcp_pm_get_subflows_max(msk);
    616
    617	rcu_read_lock();
    618	list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) {
    619		if (!(entry->flags & MPTCP_PM_ADDR_FLAG_FULLMESH))
    620			continue;
    621
    622		if (entry->addr.family != sk->sk_family) {
    623#if IS_ENABLED(CONFIG_MPTCP_IPV6)
    624			if ((entry->addr.family == AF_INET &&
    625			     !ipv6_addr_v4mapped(&sk->sk_v6_daddr)) ||
    626			    (sk->sk_family == AF_INET &&
    627			     !ipv6_addr_v4mapped(&entry->addr.addr6)))
    628#endif
    629				continue;
    630		}
    631
    632		if (msk->pm.subflows < subflows_max) {
    633			msk->pm.subflows++;
    634			addrs[i++] = entry->addr;
    635		}
    636	}
    637	rcu_read_unlock();
    638
    639	/* If the array is empty, fill in the single
    640	 * 'IPADDRANY' local address
    641	 */
    642	if (!i) {
    643		memset(&local, 0, sizeof(local));
    644		local.family = msk->pm.remote.family;
    645
    646		msk->pm.subflows++;
    647		addrs[i++] = local;
    648	}
    649
    650	return i;
    651}
    652
    653static void mptcp_pm_nl_add_addr_received(struct mptcp_sock *msk)
    654{
    655	struct mptcp_addr_info addrs[MPTCP_PM_ADDR_MAX];
    656	struct sock *sk = (struct sock *)msk;
    657	unsigned int add_addr_accept_max;
    658	struct mptcp_addr_info remote;
    659	unsigned int subflows_max;
    660	int i, nr;
    661
    662	add_addr_accept_max = mptcp_pm_get_add_addr_accept_max(msk);
    663	subflows_max = mptcp_pm_get_subflows_max(msk);
    664
    665	pr_debug("accepted %d:%d remote family %d",
    666		 msk->pm.add_addr_accepted, add_addr_accept_max,
    667		 msk->pm.remote.family);
    668
    669	remote = msk->pm.remote;
    670	mptcp_pm_announce_addr(msk, &remote, true);
    671	mptcp_pm_nl_addr_send_ack(msk);
    672
    673	if (lookup_subflow_by_daddr(&msk->conn_list, &remote))
    674		return;
    675
    676	/* pick id 0 port, if none is provided the remote address */
    677	if (!remote.port)
    678		remote.port = sk->sk_dport;
    679
    680	/* connect to the specified remote address, using whatever
    681	 * local address the routing configuration will pick.
    682	 */
    683	nr = fill_local_addresses_vec(msk, addrs);
    684
    685	msk->pm.add_addr_accepted++;
    686	if (msk->pm.add_addr_accepted >= add_addr_accept_max ||
    687	    msk->pm.subflows >= subflows_max)
    688		WRITE_ONCE(msk->pm.accept_addr, false);
    689
    690	spin_unlock_bh(&msk->pm.lock);
    691	for (i = 0; i < nr; i++)
    692		__mptcp_subflow_connect(sk, &addrs[i], &remote);
    693	spin_lock_bh(&msk->pm.lock);
    694}
    695
    696void mptcp_pm_nl_addr_send_ack(struct mptcp_sock *msk)
    697{
    698	struct mptcp_subflow_context *subflow;
    699
    700	msk_owned_by_me(msk);
    701	lockdep_assert_held(&msk->pm.lock);
    702
    703	if (!mptcp_pm_should_add_signal(msk) &&
    704	    !mptcp_pm_should_rm_signal(msk))
    705		return;
    706
    707	subflow = list_first_entry_or_null(&msk->conn_list, typeof(*subflow), node);
    708	if (subflow) {
    709		struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
    710
    711		spin_unlock_bh(&msk->pm.lock);
    712		pr_debug("send ack for %s",
    713			 mptcp_pm_should_add_signal(msk) ? "add_addr" : "rm_addr");
    714
    715		mptcp_subflow_send_ack(ssk);
    716		spin_lock_bh(&msk->pm.lock);
    717	}
    718}
    719
    720int mptcp_pm_nl_mp_prio_send_ack(struct mptcp_sock *msk,
    721				 struct mptcp_addr_info *addr,
    722				 struct mptcp_addr_info *rem,
    723				 u8 bkup)
    724{
    725	struct mptcp_subflow_context *subflow;
    726
    727	pr_debug("bkup=%d", bkup);
    728
    729	mptcp_for_each_subflow(msk, subflow) {
    730		struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
    731		struct mptcp_addr_info local, remote;
    732		bool slow;
    733
    734		local_address((struct sock_common *)ssk, &local);
    735		if (!mptcp_addresses_equal(&local, addr, addr->port))
    736			continue;
    737
    738		if (rem && rem->family != AF_UNSPEC) {
    739			remote_address((struct sock_common *)ssk, &remote);
    740			if (!mptcp_addresses_equal(&remote, rem, rem->port))
    741				continue;
    742		}
    743
    744		slow = lock_sock_fast(ssk);
    745		if (subflow->backup != bkup)
    746			msk->last_snd = NULL;
    747		subflow->backup = bkup;
    748		subflow->send_mp_prio = 1;
    749		subflow->request_bkup = bkup;
    750
    751		pr_debug("send ack for mp_prio");
    752		__mptcp_subflow_send_ack(ssk);
    753		unlock_sock_fast(ssk, slow);
    754
    755		return 0;
    756	}
    757
    758	return -EINVAL;
    759}
    760
/* Close the subflows matching the ids listed in @rm_list.
 *
 * With @rm_type == MPTCP_MIB_RMADDR a subflow matches on its remote id,
 * otherwise on its local id. Requires the msk to be owned and
 * msk->pm.lock held; the lock is released around each shutdown/close.
 * Each closed subflow bumps the @rm_type MIB counter. For
 * MPTCP_MIB_RMSUBFLOW every listed id is marked available again in
 * msk->pm.id_avail_bitmap. With the in-kernel PM, each listed id that
 * closed at least one subflow decrements add_addr_accepted (re-enabling
 * accept_addr) for RMADDR, or local_addr_used for RMSUBFLOW.
 */
static void mptcp_pm_nl_rm_addr_or_subflow(struct mptcp_sock *msk,
					   const struct mptcp_rm_list *rm_list,
					   enum linux_mptcp_mib_field rm_type)
{
	struct mptcp_subflow_context *subflow, *tmp;
	struct sock *sk = (struct sock *)msk;
	u8 i;

	pr_debug("%s rm_list_nr %d",
		 rm_type == MPTCP_MIB_RMADDR ? "address" : "subflow", rm_list->nr);

	msk_owned_by_me(msk);

	/* nothing to do on a listener, for an empty request, or without
	 * subflows
	 */
	if (sk->sk_state == TCP_LISTEN)
		return;

	if (!rm_list->nr)
		return;

	if (list_empty(&msk->conn_list))
		return;

	for (i = 0; i < rm_list->nr; i++) {
		bool removed = false;

		/* _safe iterator: mptcp_close_ssk() presumably unlinks the
		 * current subflow — TODO confirm
		 */
		list_for_each_entry_safe(subflow, tmp, &msk->conn_list, node) {
			struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
			int how = RCV_SHUTDOWN | SEND_SHUTDOWN;
			u8 id = subflow->local_id;

			if (rm_type == MPTCP_MIB_RMADDR)
				id = subflow->remote_id;

			if (rm_list->ids[i] != id)
				continue;

			pr_debug(" -> %s rm_list_ids[%d]=%u local_id=%u remote_id=%u",
				 rm_type == MPTCP_MIB_RMADDR ? "address" : "subflow",
				 i, rm_list->ids[i], subflow->local_id, subflow->remote_id);
			spin_unlock_bh(&msk->pm.lock);
			mptcp_subflow_shutdown(sk, ssk, how);

			/* the following takes care of updating the subflows counter */
			mptcp_close_ssk(sk, ssk, subflow);
			spin_lock_bh(&msk->pm.lock);

			removed = true;
			__MPTCP_INC_STATS(sock_net(sk), rm_type);
		}
		/* a removed local id becomes usable again for this msk */
		if (rm_type == MPTCP_MIB_RMSUBFLOW)
			__set_bit(rm_list->ids[i], msk->pm.id_avail_bitmap);
		if (!removed)
			continue;

		/* usage counters below are only maintained by the in-kernel PM */
		if (!mptcp_pm_is_kernel(msk))
			continue;

		if (rm_type == MPTCP_MIB_RMADDR) {
			msk->pm.add_addr_accepted--;
			WRITE_ONCE(msk->pm.accept_addr, true);
		} else if (rm_type == MPTCP_MIB_RMSUBFLOW) {
			msk->pm.local_addr_used--;
		}
	}
}
    826
    827static void mptcp_pm_nl_rm_addr_received(struct mptcp_sock *msk)
    828{
    829	mptcp_pm_nl_rm_addr_or_subflow(msk, &msk->pm.rm_list_rx, MPTCP_MIB_RMADDR);
    830}
    831
    832void mptcp_pm_nl_rm_subflow_received(struct mptcp_sock *msk,
    833				     const struct mptcp_rm_list *rm_list)
    834{
    835	mptcp_pm_nl_rm_addr_or_subflow(msk, rm_list, MPTCP_MIB_RMSUBFLOW);
    836}
    837
    838void mptcp_pm_nl_work(struct mptcp_sock *msk)
    839{
    840	struct mptcp_pm_data *pm = &msk->pm;
    841
    842	msk_owned_by_me(msk);
    843
    844	if (!(pm->status & MPTCP_PM_WORK_MASK))
    845		return;
    846
    847	spin_lock_bh(&msk->pm.lock);
    848
    849	pr_debug("msk=%p status=%x", msk, pm->status);
    850	if (pm->status & BIT(MPTCP_PM_ADD_ADDR_RECEIVED)) {
    851		pm->status &= ~BIT(MPTCP_PM_ADD_ADDR_RECEIVED);
    852		mptcp_pm_nl_add_addr_received(msk);
    853	}
    854	if (pm->status & BIT(MPTCP_PM_ADD_ADDR_SEND_ACK)) {
    855		pm->status &= ~BIT(MPTCP_PM_ADD_ADDR_SEND_ACK);
    856		mptcp_pm_nl_addr_send_ack(msk);
    857	}
    858	if (pm->status & BIT(MPTCP_PM_RM_ADDR_RECEIVED)) {
    859		pm->status &= ~BIT(MPTCP_PM_RM_ADDR_RECEIVED);
    860		mptcp_pm_nl_rm_addr_received(msk);
    861	}
    862	if (pm->status & BIT(MPTCP_PM_ESTABLISHED)) {
    863		pm->status &= ~BIT(MPTCP_PM_ESTABLISHED);
    864		mptcp_pm_nl_fully_established(msk);
    865	}
    866	if (pm->status & BIT(MPTCP_PM_SUBFLOW_ESTABLISHED)) {
    867		pm->status &= ~BIT(MPTCP_PM_SUBFLOW_ESTABLISHED);
    868		mptcp_pm_nl_subflow_established(msk);
    869	}
    870
    871	spin_unlock_bh(&msk->pm.lock);
    872}
    873
    874static bool address_use_port(struct mptcp_pm_addr_entry *entry)
    875{
    876	return (entry->flags &
    877		(MPTCP_PM_ADDR_FLAG_SIGNAL | MPTCP_PM_ADDR_FLAG_SUBFLOW)) ==
    878		MPTCP_PM_ADDR_FLAG_SIGNAL;
    879}
    880
    881/* caller must ensure the RCU grace period is already elapsed */
    882static void __mptcp_pm_release_addr_entry(struct mptcp_pm_addr_entry *entry)
    883{
    884	if (entry->lsk)
    885		sock_release(entry->lsk);
    886	kfree(entry);
    887}
    888
/* Insert @entry into the per-netns endpoint list, under pernet->lock.
 *
 * If @entry has no id (addr.id == 0), one is allocated from id_bitmap,
 * starting at next_id and retrying once from 1. The call fails in three
 * cases:
 *  - the list already holds MPTCP_PM_ADDR_MAX endpoints;
 *  - the requested id is in use;
 *  - an endpoint with the same address exists (ports are compared only
 *    when both are signal-only).
 * The last case is allowed when the existing endpoint is implicit and no
 * id was requested: it is then unlinked and its id reused, and it is freed
 * after an RCU grace period once the lock is dropped. On success the
 * SIGNAL/SUBFLOW limits are bumped according to @entry->flags.
 *
 * Return: the endpoint id (> 0) on success, -EINVAL otherwise.
 */
static int mptcp_pm_nl_append_new_local_addr(struct pm_nl_pernet *pernet,
					     struct mptcp_pm_addr_entry *entry)
{
	struct mptcp_pm_addr_entry *cur, *del_entry = NULL;
	unsigned int addr_max;
	int ret = -EINVAL;

	spin_lock_bh(&pernet->lock);
	/* to keep the code simple, don't do IDR-like allocation for address ID,
	 * just bail when we exceed limits
	 */
	if (pernet->next_id == MPTCP_PM_MAX_ADDR_ID)
		pernet->next_id = 1;
	if (pernet->addrs >= MPTCP_PM_ADDR_MAX)
		goto out;
	if (test_bit(entry->addr.id, pernet->id_bitmap))
		goto out;

	/* do not insert duplicate address, differentiate on port only
	 * singled addresses
	 */
	list_for_each_entry(cur, &pernet->local_addr_list, list) {
		if (mptcp_addresses_equal(&cur->addr, &entry->addr,
					  address_use_port(entry) &&
					  address_use_port(cur))) {
			/* allow replacing the exiting endpoint only if such
			 * endpoint is an implicit one and the user-space
			 * did not provide an endpoint id
			 */
			if (!(cur->flags & MPTCP_PM_ADDR_FLAG_IMPLICIT))
				goto out;
			if (entry->addr.id)
				goto out;

			pernet->addrs--;
			entry->addr.id = cur->addr.id;
			list_del_rcu(&cur->list);
			del_entry = cur;
			break;
		}
	}

	/* NOTE(review): find_next_zero_bit() returns MPTCP_PM_MAX_ADDR_ID + 1
	 * when no bit is free; the !entry->addr.id checks below appear to
	 * rely on that value truncating to 0 in addr.id (u8?) — verify.
	 */
	if (!entry->addr.id) {
find_next:
		entry->addr.id = find_next_zero_bit(pernet->id_bitmap,
						    MPTCP_PM_MAX_ADDR_ID + 1,
						    pernet->next_id);
		if (!entry->addr.id && pernet->next_id != 1) {
			pernet->next_id = 1;
			goto find_next;
		}
	}

	if (!entry->addr.id)
		goto out;

	__set_bit(entry->addr.id, pernet->id_bitmap);
	if (entry->addr.id > pernet->next_id)
		pernet->next_id = entry->addr.id;

	/* limits are written with WRITE_ONCE(): the per-msk getters read
	 * them locklessly with READ_ONCE()
	 */
	if (entry->flags & MPTCP_PM_ADDR_FLAG_SIGNAL) {
		addr_max = pernet->add_addr_signal_max;
		WRITE_ONCE(pernet->add_addr_signal_max, addr_max + 1);
	}
	if (entry->flags & MPTCP_PM_ADDR_FLAG_SUBFLOW) {
		addr_max = pernet->local_addr_max;
		WRITE_ONCE(pernet->local_addr_max, addr_max + 1);
	}

	pernet->addrs++;
	list_add_tail_rcu(&entry->list, &pernet->local_addr_list);
	ret = entry->addr.id;

out:
	spin_unlock_bh(&pernet->lock);

	/* just replaced an existing entry, free it */
	if (del_entry) {
		synchronize_rcu();
		__mptcp_pm_release_addr_entry(del_entry);
	}
	return ret;
}
    972
    973static int mptcp_pm_nl_create_listen_socket(struct sock *sk,
    974					    struct mptcp_pm_addr_entry *entry)
    975{
    976	int addrlen = sizeof(struct sockaddr_in);
    977	struct sockaddr_storage addr;
    978	struct mptcp_sock *msk;
    979	struct socket *ssock;
    980	int backlog = 1024;
    981	int err;
    982
    983	err = sock_create_kern(sock_net(sk), entry->addr.family,
    984			       SOCK_STREAM, IPPROTO_MPTCP, &entry->lsk);
    985	if (err)
    986		return err;
    987
    988	msk = mptcp_sk(entry->lsk->sk);
    989	if (!msk) {
    990		err = -EINVAL;
    991		goto out;
    992	}
    993
    994	ssock = __mptcp_nmpc_socket(msk);
    995	if (!ssock) {
    996		err = -EINVAL;
    997		goto out;
    998	}
    999
   1000	mptcp_info2sockaddr(&entry->addr, &addr, entry->addr.family);
   1001#if IS_ENABLED(CONFIG_MPTCP_IPV6)
   1002	if (entry->addr.family == AF_INET6)
   1003		addrlen = sizeof(struct sockaddr_in6);
   1004#endif
   1005	err = kernel_bind(ssock, (struct sockaddr *)&addr, addrlen);
   1006	if (err) {
   1007		pr_warn("kernel_bind error, err=%d", err);
   1008		goto out;
   1009	}
   1010
   1011	err = kernel_listen(ssock, backlog);
   1012	if (err) {
   1013		pr_warn("kernel_listen error, err=%d", err);
   1014		goto out;
   1015	}
   1016
   1017	return 0;
   1018
   1019out:
   1020	sock_release(entry->lsk);
   1021	return err;
   1022}
   1023
   1024int mptcp_pm_nl_get_local_id(struct mptcp_sock *msk, struct sock_common *skc)
   1025{
   1026	struct mptcp_pm_addr_entry *entry;
   1027	struct mptcp_addr_info skc_local;
   1028	struct mptcp_addr_info msk_local;
   1029	struct pm_nl_pernet *pernet;
   1030	int ret = -1;
   1031
   1032	if (WARN_ON_ONCE(!msk))
   1033		return -1;
   1034
   1035	/* The 0 ID mapping is defined by the first subflow, copied into the msk
   1036	 * addr
   1037	 */
   1038	local_address((struct sock_common *)msk, &msk_local);
   1039	local_address((struct sock_common *)skc, &skc_local);
   1040	if (mptcp_addresses_equal(&msk_local, &skc_local, false))
   1041		return 0;
   1042
   1043	if (mptcp_pm_is_userspace(msk))
   1044		return mptcp_userspace_pm_get_local_id(msk, &skc_local);
   1045
   1046	pernet = pm_nl_get_pernet_from_msk(msk);
   1047
   1048	rcu_read_lock();
   1049	list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) {
   1050		if (mptcp_addresses_equal(&entry->addr, &skc_local, entry->addr.port)) {
   1051			ret = entry->addr.id;
   1052			break;
   1053		}
   1054	}
   1055	rcu_read_unlock();
   1056	if (ret >= 0)
   1057		return ret;
   1058
   1059	/* address not found, add to local list */
   1060	entry = kmalloc(sizeof(*entry), GFP_ATOMIC);
   1061	if (!entry)
   1062		return -ENOMEM;
   1063
   1064	entry->addr = skc_local;
   1065	entry->addr.id = 0;
   1066	entry->addr.port = 0;
   1067	entry->ifindex = 0;
   1068	entry->flags = MPTCP_PM_ADDR_FLAG_IMPLICIT;
   1069	entry->lsk = NULL;
   1070	ret = mptcp_pm_nl_append_new_local_addr(pernet, entry);
   1071	if (ret < 0)
   1072		kfree(entry);
   1073
   1074	return ret;
   1075}
   1076
/* Indexes into mptcp_pm_mcgrps[]. The event group is the one PM events are
 * multicast to (mptcp_nl_mcast_send()) and whose listeners are checked by
 * mptcp_userspace_pm_active().
 */
#define MPTCP_PM_CMD_GRP_OFFSET       0
#define MPTCP_PM_EV_GRP_OFFSET        1

/* Multicast groups of the MPTCP PM genetlink family. Subscribing to the
 * event group is restricted by GENL_UNS_ADMIN_PERM (admin capability in
 * the user namespace owning the netns).
 */
static const struct genl_multicast_group mptcp_pm_mcgrps[] = {
	[MPTCP_PM_CMD_GRP_OFFSET]	= { .name = MPTCP_PM_CMD_GRP_NAME, },
	[MPTCP_PM_EV_GRP_OFFSET]        = { .name = MPTCP_PM_EV_GRP_NAME,
					    .flags = GENL_UNS_ADMIN_PERM,
					  },
};
   1086
/* Validation policy for the attributes nested inside MPTCP_PM_ATTR_ADDR and
 * MPTCP_PM_ATTR_ADDR_REMOTE. ADDR4 carries an IPv4 address as a u32, ADDR6
 * must be exactly sizeof(struct in6_addr) bytes, IF_IDX is signed.
 */
static const struct nla_policy
mptcp_pm_addr_policy[MPTCP_PM_ADDR_ATTR_MAX + 1] = {
	[MPTCP_PM_ADDR_ATTR_FAMILY]	= { .type	= NLA_U16,	},
	[MPTCP_PM_ADDR_ATTR_ID]		= { .type	= NLA_U8,	},
	[MPTCP_PM_ADDR_ATTR_ADDR4]	= { .type	= NLA_U32,	},
	[MPTCP_PM_ADDR_ATTR_ADDR6]	=
		NLA_POLICY_EXACT_LEN(sizeof(struct in6_addr)),
	[MPTCP_PM_ADDR_ATTR_PORT]	= { .type	= NLA_U16	},
	[MPTCP_PM_ADDR_ATTR_FLAGS]	= { .type	= NLA_U32	},
	[MPTCP_PM_ADDR_ATTR_IF_IDX]     = { .type	= NLA_S32	},
};
   1098
/* Top-level attribute policy for PM commands. Both the local (ADDR) and the
 * remote (ADDR_REMOTE) endpoint attributes are nested and validated with
 * mptcp_pm_addr_policy.
 */
static const struct nla_policy mptcp_pm_policy[MPTCP_PM_ATTR_MAX + 1] = {
	[MPTCP_PM_ATTR_ADDR]		=
					NLA_POLICY_NESTED(mptcp_pm_addr_policy),
	[MPTCP_PM_ATTR_RCV_ADD_ADDRS]	= { .type	= NLA_U32,	},
	[MPTCP_PM_ATTR_SUBFLOWS]	= { .type	= NLA_U32,	},
	[MPTCP_PM_ATTR_TOKEN]		= { .type	= NLA_U32,	},
	[MPTCP_PM_ATTR_LOC_ID]		= { .type	= NLA_U8,	},
	[MPTCP_PM_ATTR_ADDR_REMOTE]	=
					NLA_POLICY_NESTED(mptcp_pm_addr_policy),
};
   1109
   1110void mptcp_pm_nl_subflow_chk_stale(const struct mptcp_sock *msk, struct sock *ssk)
   1111{
   1112	struct mptcp_subflow_context *iter, *subflow = mptcp_subflow_ctx(ssk);
   1113	struct sock *sk = (struct sock *)msk;
   1114	unsigned int active_max_loss_cnt;
   1115	struct net *net = sock_net(sk);
   1116	unsigned int stale_loss_cnt;
   1117	bool slow;
   1118
   1119	stale_loss_cnt = mptcp_stale_loss_cnt(net);
   1120	if (subflow->stale || !stale_loss_cnt || subflow->stale_count <= stale_loss_cnt)
   1121		return;
   1122
   1123	/* look for another available subflow not in loss state */
   1124	active_max_loss_cnt = max_t(int, stale_loss_cnt - 1, 1);
   1125	mptcp_for_each_subflow(msk, iter) {
   1126		if (iter != subflow && mptcp_subflow_active(iter) &&
   1127		    iter->stale_count < active_max_loss_cnt) {
   1128			/* we have some alternatives, try to mark this subflow as idle ...*/
   1129			slow = lock_sock_fast(ssk);
   1130			if (!tcp_rtx_and_write_queues_empty(ssk)) {
   1131				subflow->stale = 1;
   1132				__mptcp_retransmit_pending_data(sk);
   1133				MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_SUBFLOWSTALE);
   1134			}
   1135			unlock_sock_fast(ssk, slow);
   1136
   1137			/* always try to push the pending data regarless of re-injections:
   1138			 * we can possibly use backup subflows now, and subflow selection
   1139			 * is cheap under the msk socket lock
   1140			 */
   1141			__mptcp_push_pending(sk, 0);
   1142			return;
   1143		}
   1144	}
   1145}
   1146
   1147static int mptcp_pm_family_to_addr(int family)
   1148{
   1149#if IS_ENABLED(CONFIG_MPTCP_IPV6)
   1150	if (family == AF_INET6)
   1151		return MPTCP_PM_ADDR_ATTR_ADDR6;
   1152#endif
   1153	return MPTCP_PM_ADDR_ATTR_ADDR4;
   1154}
   1155
/* Parse the nested address attribute @attr into @tb and fill @addr.
 *
 * @tb: caller-provided array of MPTCP_PM_ADDR_ATTR_MAX + 1 slots receiving
 *      the parsed nested attributes (callers may inspect it afterwards)
 * @attr: nested address attribute; NULL is reported as an error
 * @info: request info, used for extack error reporting
 * @addr: output; only fields present in @attr are written (id, family,
 *        address, port), the visible callers clear it beforehand
 * @require_family: when false a missing family is accepted, and the
 *        function returns right after reading the optional id (address
 *        and port are then not parsed here)
 *
 * Return: 0 on success, negative errno on failure.
 */
static int mptcp_pm_parse_pm_addr_attr(struct nlattr *tb[],
				       const struct nlattr *attr,
				       struct genl_info *info,
				       struct mptcp_addr_info *addr,
				       bool require_family)
{
	int err, addr_addr;

	if (!attr) {
		GENL_SET_ERR_MSG(info, "missing address info");
		return -EINVAL;
	}

	/* no validation needed - was already done via nested policy */
	err = nla_parse_nested_deprecated(tb, MPTCP_PM_ADDR_ATTR_MAX, attr,
					  mptcp_pm_addr_policy, info->extack);
	if (err)
		return err;

	if (tb[MPTCP_PM_ADDR_ATTR_ID])
		addr->id = nla_get_u8(tb[MPTCP_PM_ADDR_ATTR_ID]);

	if (!tb[MPTCP_PM_ADDR_ATTR_FAMILY]) {
		/* err is 0 here: an id-only description is acceptable */
		if (!require_family)
			return err;

		NL_SET_ERR_MSG_ATTR(info->extack, attr,
				    "missing family");
		return -EINVAL;
	}

	addr->family = nla_get_u16(tb[MPTCP_PM_ADDR_ATTR_FAMILY]);
	if (addr->family != AF_INET
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
	    && addr->family != AF_INET6
#endif
	    ) {
		NL_SET_ERR_MSG_ATTR(info->extack, attr,
				    "unknown address family");
		return -EINVAL;
	}
	/* once a family is given, the matching address attribute is mandatory */
	addr_addr = mptcp_pm_family_to_addr(addr->family);
	if (!tb[addr_addr]) {
		NL_SET_ERR_MSG_ATTR(info->extack, attr,
				    "missing address data");
		return -EINVAL;
	}

#if IS_ENABLED(CONFIG_MPTCP_IPV6)
	if (addr->family == AF_INET6)
		addr->addr6 = nla_get_in6_addr(tb[addr_addr]);
	else
#endif
		addr->addr.s_addr = nla_get_in_addr(tb[addr_addr]);

	/* the attribute is in host order, addr->port is stored in network order */
	if (tb[MPTCP_PM_ADDR_ATTR_PORT])
		addr->port = htons(nla_get_u16(tb[MPTCP_PM_ADDR_ATTR_PORT]));

	return err;
}
   1216
   1217int mptcp_pm_parse_addr(struct nlattr *attr, struct genl_info *info,
   1218			struct mptcp_addr_info *addr)
   1219{
   1220	struct nlattr *tb[MPTCP_PM_ADDR_ATTR_MAX + 1];
   1221
   1222	memset(addr, 0, sizeof(*addr));
   1223
   1224	return mptcp_pm_parse_pm_addr_attr(tb, attr, info, addr, true);
   1225}
   1226
   1227int mptcp_pm_parse_entry(struct nlattr *attr, struct genl_info *info,
   1228			 bool require_family,
   1229			 struct mptcp_pm_addr_entry *entry)
   1230{
   1231	struct nlattr *tb[MPTCP_PM_ADDR_ATTR_MAX + 1];
   1232	int err;
   1233
   1234	memset(entry, 0, sizeof(*entry));
   1235
   1236	err = mptcp_pm_parse_pm_addr_attr(tb, attr, info, &entry->addr, require_family);
   1237	if (err)
   1238		return err;
   1239
   1240	if (tb[MPTCP_PM_ADDR_ATTR_IF_IDX]) {
   1241		u32 val = nla_get_s32(tb[MPTCP_PM_ADDR_ATTR_IF_IDX]);
   1242
   1243		entry->ifindex = val;
   1244	}
   1245
   1246	if (tb[MPTCP_PM_ADDR_ATTR_FLAGS])
   1247		entry->flags = nla_get_u32(tb[MPTCP_PM_ADDR_ATTR_FLAGS]);
   1248
   1249	if (tb[MPTCP_PM_ADDR_ATTR_PORT])
   1250		entry->addr.port = htons(nla_get_u16(tb[MPTCP_PM_ADDR_ATTR_PORT]));
   1251
   1252	return 0;
   1253}
   1254
   1255static struct pm_nl_pernet *genl_info_pm_nl(struct genl_info *info)
   1256{
   1257	return pm_nl_get_pernet(genl_info_net(info));
   1258}
   1259
   1260static int mptcp_nl_add_subflow_or_signal_addr(struct net *net)
   1261{
   1262	struct mptcp_sock *msk;
   1263	long s_slot = 0, s_num = 0;
   1264
   1265	while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) {
   1266		struct sock *sk = (struct sock *)msk;
   1267
   1268		if (!READ_ONCE(msk->fully_established) ||
   1269		    mptcp_pm_is_userspace(msk))
   1270			goto next;
   1271
   1272		lock_sock(sk);
   1273		spin_lock_bh(&msk->pm.lock);
   1274		mptcp_pm_create_subflow_or_signal_addr(msk);
   1275		spin_unlock_bh(&msk->pm.lock);
   1276		release_sock(sk);
   1277
   1278next:
   1279		sock_put(sk);
   1280		cond_resched();
   1281	}
   1282
   1283	return 0;
   1284}
   1285
   1286static int mptcp_nl_cmd_add_addr(struct sk_buff *skb, struct genl_info *info)
   1287{
   1288	struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
   1289	struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
   1290	struct mptcp_pm_addr_entry addr, *entry;
   1291	int ret;
   1292
   1293	ret = mptcp_pm_parse_entry(attr, info, true, &addr);
   1294	if (ret < 0)
   1295		return ret;
   1296
   1297	if (addr.addr.port && !(addr.flags & MPTCP_PM_ADDR_FLAG_SIGNAL)) {
   1298		GENL_SET_ERR_MSG(info, "flags must have signal when using port");
   1299		return -EINVAL;
   1300	}
   1301
   1302	if (addr.flags & MPTCP_PM_ADDR_FLAG_SIGNAL &&
   1303	    addr.flags & MPTCP_PM_ADDR_FLAG_FULLMESH) {
   1304		GENL_SET_ERR_MSG(info, "flags mustn't have both signal and fullmesh");
   1305		return -EINVAL;
   1306	}
   1307
   1308	if (addr.flags & MPTCP_PM_ADDR_FLAG_IMPLICIT) {
   1309		GENL_SET_ERR_MSG(info, "can't create IMPLICIT endpoint");
   1310		return -EINVAL;
   1311	}
   1312
   1313	entry = kmalloc(sizeof(*entry), GFP_KERNEL);
   1314	if (!entry) {
   1315		GENL_SET_ERR_MSG(info, "can't allocate addr");
   1316		return -ENOMEM;
   1317	}
   1318
   1319	*entry = addr;
   1320	if (entry->addr.port) {
   1321		ret = mptcp_pm_nl_create_listen_socket(skb->sk, entry);
   1322		if (ret) {
   1323			GENL_SET_ERR_MSG(info, "create listen socket error");
   1324			kfree(entry);
   1325			return ret;
   1326		}
   1327	}
   1328	ret = mptcp_pm_nl_append_new_local_addr(pernet, entry);
   1329	if (ret < 0) {
   1330		GENL_SET_ERR_MSG(info, "too many addresses or duplicate one");
   1331		if (entry->lsk)
   1332			sock_release(entry->lsk);
   1333		kfree(entry);
   1334		return ret;
   1335	}
   1336
   1337	mptcp_nl_add_subflow_or_signal_addr(sock_net(skb->sk));
   1338
   1339	return 0;
   1340}
   1341
   1342int mptcp_pm_get_flags_and_ifindex_by_id(struct mptcp_sock *msk, unsigned int id,
   1343					 u8 *flags, int *ifindex)
   1344{
   1345	struct mptcp_pm_addr_entry *entry;
   1346	struct sock *sk = (struct sock *)msk;
   1347	struct net *net = sock_net(sk);
   1348
   1349	*flags = 0;
   1350	*ifindex = 0;
   1351
   1352	if (id) {
   1353		if (mptcp_pm_is_userspace(msk))
   1354			return mptcp_userspace_pm_get_flags_and_ifindex_by_id(msk,
   1355									      id,
   1356									      flags,
   1357									      ifindex);
   1358
   1359		rcu_read_lock();
   1360		entry = __lookup_addr_by_id(pm_nl_get_pernet(net), id);
   1361		if (entry) {
   1362			*flags = entry->flags;
   1363			*ifindex = entry->ifindex;
   1364		}
   1365		rcu_read_unlock();
   1366	}
   1367
   1368	return 0;
   1369}
   1370
   1371static bool remove_anno_list_by_saddr(struct mptcp_sock *msk,
   1372				      const struct mptcp_addr_info *addr)
   1373{
   1374	struct mptcp_pm_add_entry *entry;
   1375
   1376	entry = mptcp_pm_del_add_timer(msk, addr, false);
   1377	if (entry) {
   1378		list_del(&entry->list);
   1379		kfree(entry);
   1380		return true;
   1381	}
   1382
   1383	return false;
   1384}
   1385
   1386static bool mptcp_pm_remove_anno_addr(struct mptcp_sock *msk,
   1387				      const struct mptcp_addr_info *addr,
   1388				      bool force)
   1389{
   1390	struct mptcp_rm_list list = { .nr = 0 };
   1391	bool ret;
   1392
   1393	list.ids[list.nr++] = addr->id;
   1394
   1395	ret = remove_anno_list_by_saddr(msk, addr);
   1396	if (ret || force) {
   1397		spin_lock_bh(&msk->pm.lock);
   1398		mptcp_pm_remove_addr(msk, &list);
   1399		spin_unlock_bh(&msk->pm.lock);
   1400	}
   1401	return ret;
   1402}
   1403
/* Withdraw endpoint @entry from every MPTCP socket in @net that uses the
 * in-kernel PM (userspace-PM sockets are skipped).
 *
 * For each socket, any pending announcement of the address is dropped via
 * mptcp_pm_remove_anno_addr(). When a subflow uses the address as source
 * (lookup_subflow_by_saddr()), that subflow is removed as well and the
 * address removal is forced even without a pending announcement, unless
 * the endpoint is IMPLICIT.
 *
 * Return: always 0.
 */
static int mptcp_nl_remove_subflow_and_signal_addr(struct net *net,
						   const struct mptcp_pm_addr_entry *entry)
{
	const struct mptcp_addr_info *addr = &entry->addr;
	struct mptcp_rm_list list = { .nr = 0 };
	long s_slot = 0, s_num = 0;
	struct mptcp_sock *msk;

	pr_debug("remove_id=%d", addr->id);

	/* single-id removal list, shared by all sockets below */
	list.ids[list.nr++] = addr->id;

	/* each msk comes with a reference, dropped by sock_put() at 'next' */
	while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) {
		struct sock *sk = (struct sock *)msk;
		bool remove_subflow;

		if (mptcp_pm_is_userspace(msk))
			goto next;

		/* no subflows: nothing to look up, only drop a pending
		 * announcement (done without taking the msk socket lock)
		 */
		if (list_empty(&msk->conn_list)) {
			mptcp_pm_remove_anno_addr(msk, addr, false);
			goto next;
		}

		lock_sock(sk);
		remove_subflow = lookup_subflow_by_saddr(&msk->conn_list, addr);
		/* force the address removal when a subflow uses it, except
		 * for IMPLICIT endpoints
		 */
		mptcp_pm_remove_anno_addr(msk, addr, remove_subflow &&
					  !(entry->flags & MPTCP_PM_ADDR_FLAG_IMPLICIT));
		if (remove_subflow)
			mptcp_pm_remove_subflow(msk, &list);
		release_sock(sk);

next:
		sock_put(sk);
		cond_resched();
	}

	return 0;
}
   1443
/* Handle deletion of address id 0 for every in-kernel-PM managed MPTCP
 * socket in @net whose own (msk) local address matches @addr, comparing
 * ports too when @addr carries one.
 *
 * Id 0 is not a stored endpoint: it names each msk's own local address, so
 * the match is done per socket. Sockets without subflows and userspace-PM
 * sockets are skipped. For matching sockets, id 0 is processed by both
 * mptcp_pm_remove_addr() and mptcp_pm_nl_rm_subflow_received().
 *
 * Return: always 0.
 */
static int mptcp_nl_remove_id_zero_address(struct net *net,
					   struct mptcp_addr_info *addr)
{
	struct mptcp_rm_list list = { .nr = 0 };
	long s_slot = 0, s_num = 0;
	struct mptcp_sock *msk;

	list.ids[list.nr++] = 0;

	while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) {
		struct sock *sk = (struct sock *)msk;
		struct mptcp_addr_info msk_local;

		if (list_empty(&msk->conn_list) || mptcp_pm_is_userspace(msk))
			goto next;

		/* only sockets whose id-0 address is @addr are affected */
		local_address((struct sock_common *)msk, &msk_local);
		if (!mptcp_addresses_equal(&msk_local, addr, addr->port))
			goto next;

		lock_sock(sk);
		spin_lock_bh(&msk->pm.lock);
		mptcp_pm_remove_addr(msk, &list);
		mptcp_pm_nl_rm_subflow_received(msk, &list);
		spin_unlock_bh(&msk->pm.lock);
		release_sock(sk);

next:
		/* drop the reference taken by the token iterator */
		sock_put(sk);
		cond_resched();
	}

	return 0;
}
   1478
/* MPTCP_PM_CMD_DEL_ADDR handler: delete the endpoint with the given id.
 *
 * Id 0 is delegated to mptcp_nl_remove_id_zero_address(). For other ids the
 * endpoint is unlinked, its id released and the signal/subflow counters
 * decremented under the pernet lock. It is then withdrawn from existing
 * sockets and freed only after an RCU grace period, because lockless
 * readers (rcu_read_lock() + list_for_each_entry_rcu()) may still see it.
 *
 * Return: 0 on success, -EINVAL if no endpoint has that id, or the parse
 * error.
 */
static int mptcp_nl_cmd_del_addr(struct sk_buff *skb, struct genl_info *info)
{
	struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
	struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
	struct mptcp_pm_addr_entry addr, *entry;
	unsigned int addr_max;
	int ret;

	ret = mptcp_pm_parse_entry(attr, info, false, &addr);
	if (ret < 0)
		return ret;

	/* the zero id address is special: the first address used by the msk
	 * always gets such an id, so different subflows can have different zero
	 * id addresses. Additionally zero id is not accounted for in id_bitmap.
	 * Let's use an 'mptcp_rm_list' instead of the common remove code.
	 */
	if (addr.addr.id == 0)
		return mptcp_nl_remove_id_zero_address(sock_net(skb->sk), &addr.addr);

	spin_lock_bh(&pernet->lock);
	entry = __lookup_addr_by_id(pernet, addr.addr.id);
	if (!entry) {
		GENL_SET_ERR_MSG(info, "address not found");
		spin_unlock_bh(&pernet->lock);
		return -EINVAL;
	}
	/* undo the accounting done when the endpoint was added */
	if (entry->flags & MPTCP_PM_ADDR_FLAG_SIGNAL) {
		addr_max = pernet->add_addr_signal_max;
		WRITE_ONCE(pernet->add_addr_signal_max, addr_max - 1);
	}
	if (entry->flags & MPTCP_PM_ADDR_FLAG_SUBFLOW) {
		addr_max = pernet->local_addr_max;
		WRITE_ONCE(pernet->local_addr_max, addr_max - 1);
	}

	pernet->addrs--;
	list_del_rcu(&entry->list);
	__clear_bit(entry->addr.id, pernet->id_bitmap);
	spin_unlock_bh(&pernet->lock);

	/* entry is unlinked but still valid: use it to update the sockets,
	 * then wait for RCU readers before releasing it
	 */
	mptcp_nl_remove_subflow_and_signal_addr(sock_net(skb->sk), entry);
	synchronize_rcu();
	__mptcp_pm_release_addr_entry(entry);

	/* ret is 0 here: the parse above succeeded */
	return ret;
}
   1526
/* Withdraw every endpoint on @rm_list from @msk.
 *
 * Ids of endpoints used as source address by one of @msk's subflows are
 * collected in slist; ids whose pending announcement was dropped go into
 * alist (each list capped at MPTCP_RM_IDS_MAX). alist is then processed by
 * mptcp_pm_remove_addr() under pm.lock and slist by
 * mptcp_pm_remove_subflow().
 *
 * The caller in this file (mptcp_nl_remove_addrs_list()) holds the msk
 * socket lock.
 */
void mptcp_pm_remove_addrs_and_subflows(struct mptcp_sock *msk,
					struct list_head *rm_list)
{
	struct mptcp_rm_list alist = { .nr = 0 }, slist = { .nr = 0 };
	struct mptcp_pm_addr_entry *entry;

	list_for_each_entry(entry, rm_list, list) {
		if (lookup_subflow_by_saddr(&msk->conn_list, &entry->addr) &&
		    slist.nr < MPTCP_RM_IDS_MAX)
			slist.ids[slist.nr++] = entry->addr.id;

		/* remove_anno_list_by_saddr() is evaluated first, so the
		 * announcement is dropped even when alist is already full
		 */
		if (remove_anno_list_by_saddr(msk, &entry->addr) &&
		    alist.nr < MPTCP_RM_IDS_MAX)
			alist.ids[alist.nr++] = entry->addr.id;
	}

	if (alist.nr) {
		spin_lock_bh(&msk->pm.lock);
		mptcp_pm_remove_addr(msk, &alist);
		spin_unlock_bh(&msk->pm.lock);
	}
	if (slist.nr)
		mptcp_pm_remove_subflow(msk, &slist);
}
   1551
   1552static void mptcp_nl_remove_addrs_list(struct net *net,
   1553				       struct list_head *rm_list)
   1554{
   1555	long s_slot = 0, s_num = 0;
   1556	struct mptcp_sock *msk;
   1557
   1558	if (list_empty(rm_list))
   1559		return;
   1560
   1561	while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) {
   1562		struct sock *sk = (struct sock *)msk;
   1563
   1564		if (!mptcp_pm_is_userspace(msk)) {
   1565			lock_sock(sk);
   1566			mptcp_pm_remove_addrs_and_subflows(msk, rm_list);
   1567			release_sock(sk);
   1568		}
   1569
   1570		sock_put(sk);
   1571		cond_resched();
   1572	}
   1573}
   1574
   1575/* caller must ensure the RCU grace period is already elapsed */
   1576static void __flush_addrs(struct list_head *list)
   1577{
   1578	while (!list_empty(list)) {
   1579		struct mptcp_pm_addr_entry *cur;
   1580
   1581		cur = list_entry(list->next,
   1582				 struct mptcp_pm_addr_entry, list);
   1583		list_del_rcu(&cur->list);
   1584		__mptcp_pm_release_addr_entry(cur);
   1585	}
   1586}
   1587
   1588static void __reset_counters(struct pm_nl_pernet *pernet)
   1589{
   1590	WRITE_ONCE(pernet->add_addr_signal_max, 0);
   1591	WRITE_ONCE(pernet->add_addr_accept_max, 0);
   1592	WRITE_ONCE(pernet->local_addr_max, 0);
   1593	pernet->addrs = 0;
   1594}
   1595
   1596static int mptcp_nl_cmd_flush_addrs(struct sk_buff *skb, struct genl_info *info)
   1597{
   1598	struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
   1599	LIST_HEAD(free_list);
   1600
   1601	spin_lock_bh(&pernet->lock);
   1602	list_splice_init(&pernet->local_addr_list, &free_list);
   1603	__reset_counters(pernet);
   1604	pernet->next_id = 1;
   1605	bitmap_zero(pernet->id_bitmap, MPTCP_PM_MAX_ADDR_ID + 1);
   1606	spin_unlock_bh(&pernet->lock);
   1607	mptcp_nl_remove_addrs_list(sock_net(skb->sk), &free_list);
   1608	synchronize_rcu();
   1609	__flush_addrs(&free_list);
   1610	return 0;
   1611}
   1612
   1613static int mptcp_nl_fill_addr(struct sk_buff *skb,
   1614			      struct mptcp_pm_addr_entry *entry)
   1615{
   1616	struct mptcp_addr_info *addr = &entry->addr;
   1617	struct nlattr *attr;
   1618
   1619	attr = nla_nest_start(skb, MPTCP_PM_ATTR_ADDR);
   1620	if (!attr)
   1621		return -EMSGSIZE;
   1622
   1623	if (nla_put_u16(skb, MPTCP_PM_ADDR_ATTR_FAMILY, addr->family))
   1624		goto nla_put_failure;
   1625	if (nla_put_u16(skb, MPTCP_PM_ADDR_ATTR_PORT, ntohs(addr->port)))
   1626		goto nla_put_failure;
   1627	if (nla_put_u8(skb, MPTCP_PM_ADDR_ATTR_ID, addr->id))
   1628		goto nla_put_failure;
   1629	if (nla_put_u32(skb, MPTCP_PM_ADDR_ATTR_FLAGS, entry->flags))
   1630		goto nla_put_failure;
   1631	if (entry->ifindex &&
   1632	    nla_put_s32(skb, MPTCP_PM_ADDR_ATTR_IF_IDX, entry->ifindex))
   1633		goto nla_put_failure;
   1634
   1635	if (addr->family == AF_INET &&
   1636	    nla_put_in_addr(skb, MPTCP_PM_ADDR_ATTR_ADDR4,
   1637			    addr->addr.s_addr))
   1638		goto nla_put_failure;
   1639#if IS_ENABLED(CONFIG_MPTCP_IPV6)
   1640	else if (addr->family == AF_INET6 &&
   1641		 nla_put_in6_addr(skb, MPTCP_PM_ADDR_ATTR_ADDR6, &addr->addr6))
   1642		goto nla_put_failure;
   1643#endif
   1644	nla_nest_end(skb, attr);
   1645	return 0;
   1646
   1647nla_put_failure:
   1648	nla_nest_cancel(skb, attr);
   1649	return -EMSGSIZE;
   1650}
   1651
   1652static int mptcp_nl_cmd_get_addr(struct sk_buff *skb, struct genl_info *info)
   1653{
   1654	struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
   1655	struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
   1656	struct mptcp_pm_addr_entry addr, *entry;
   1657	struct sk_buff *msg;
   1658	void *reply;
   1659	int ret;
   1660
   1661	ret = mptcp_pm_parse_entry(attr, info, false, &addr);
   1662	if (ret < 0)
   1663		return ret;
   1664
   1665	msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
   1666	if (!msg)
   1667		return -ENOMEM;
   1668
   1669	reply = genlmsg_put_reply(msg, info, &mptcp_genl_family, 0,
   1670				  info->genlhdr->cmd);
   1671	if (!reply) {
   1672		GENL_SET_ERR_MSG(info, "not enough space in Netlink message");
   1673		ret = -EMSGSIZE;
   1674		goto fail;
   1675	}
   1676
   1677	spin_lock_bh(&pernet->lock);
   1678	entry = __lookup_addr_by_id(pernet, addr.addr.id);
   1679	if (!entry) {
   1680		GENL_SET_ERR_MSG(info, "address not found");
   1681		ret = -EINVAL;
   1682		goto unlock_fail;
   1683	}
   1684
   1685	ret = mptcp_nl_fill_addr(msg, entry);
   1686	if (ret)
   1687		goto unlock_fail;
   1688
   1689	genlmsg_end(msg, reply);
   1690	ret = genlmsg_reply(msg, info);
   1691	spin_unlock_bh(&pernet->lock);
   1692	return ret;
   1693
   1694unlock_fail:
   1695	spin_unlock_bh(&pernet->lock);
   1696
   1697fail:
   1698	nlmsg_free(msg);
   1699	return ret;
   1700}
   1701
/* Netlink dump handler listing all endpoints of the netns, one
 * MPTCP_PM_CMD_GET_ADDR message each, in increasing id order.
 *
 * cb->args[0] holds the id of the last endpoint emitted, so that the next
 * invocation resumes after it when @msg filled up.
 *
 * Return: msg->len.
 */
static int mptcp_nl_cmd_dump_addrs(struct sk_buff *msg,
				   struct netlink_callback *cb)
{
	struct net *net = sock_net(msg->sk);
	struct mptcp_pm_addr_entry *entry;
	struct pm_nl_pernet *pernet;
	int id = cb->args[0];
	void *hdr;
	int i;

	pernet = pm_nl_get_pernet(net);

	spin_lock_bh(&pernet->lock);
	for (i = id; i < MPTCP_PM_MAX_ADDR_ID + 1; i++) {
		if (test_bit(i, pernet->id_bitmap)) {
			/* id marked in use but not on the list: stop here */
			entry = __lookup_addr_by_id(pernet, i);
			if (!entry)
				break;

			/* already emitted by the previous invocation */
			if (entry->addr.id <= id)
				continue;

			hdr = genlmsg_put(msg, NETLINK_CB(cb->skb).portid,
					  cb->nlh->nlmsg_seq, &mptcp_genl_family,
					  NLM_F_MULTI, MPTCP_PM_CMD_GET_ADDR);
			if (!hdr)
				break;

			if (mptcp_nl_fill_addr(msg, entry) < 0) {
				genlmsg_cancel(msg, hdr);
				break;
			}

			/* advance the resume cursor only once fully emitted */
			id = entry->addr.id;
			genlmsg_end(msg, hdr);
		}
	}
	spin_unlock_bh(&pernet->lock);

	cb->args[0] = id;
	return msg->len;
}
   1744
   1745static int parse_limit(struct genl_info *info, int id, unsigned int *limit)
   1746{
   1747	struct nlattr *attr = info->attrs[id];
   1748
   1749	if (!attr)
   1750		return 0;
   1751
   1752	*limit = nla_get_u32(attr);
   1753	if (*limit > MPTCP_PM_ADDR_MAX) {
   1754		GENL_SET_ERR_MSG(info, "limit greater than maximum");
   1755		return -EINVAL;
   1756	}
   1757	return 0;
   1758}
   1759
   1760static int
   1761mptcp_nl_cmd_set_limits(struct sk_buff *skb, struct genl_info *info)
   1762{
   1763	struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
   1764	unsigned int rcv_addrs, subflows;
   1765	int ret;
   1766
   1767	spin_lock_bh(&pernet->lock);
   1768	rcv_addrs = pernet->add_addr_accept_max;
   1769	ret = parse_limit(info, MPTCP_PM_ATTR_RCV_ADD_ADDRS, &rcv_addrs);
   1770	if (ret)
   1771		goto unlock;
   1772
   1773	subflows = pernet->subflows_max;
   1774	ret = parse_limit(info, MPTCP_PM_ATTR_SUBFLOWS, &subflows);
   1775	if (ret)
   1776		goto unlock;
   1777
   1778	WRITE_ONCE(pernet->add_addr_accept_max, rcv_addrs);
   1779	WRITE_ONCE(pernet->subflows_max, subflows);
   1780
   1781unlock:
   1782	spin_unlock_bh(&pernet->lock);
   1783	return ret;
   1784}
   1785
   1786static int
   1787mptcp_nl_cmd_get_limits(struct sk_buff *skb, struct genl_info *info)
   1788{
   1789	struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
   1790	struct sk_buff *msg;
   1791	void *reply;
   1792
   1793	msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
   1794	if (!msg)
   1795		return -ENOMEM;
   1796
   1797	reply = genlmsg_put_reply(msg, info, &mptcp_genl_family, 0,
   1798				  MPTCP_PM_CMD_GET_LIMITS);
   1799	if (!reply)
   1800		goto fail;
   1801
   1802	if (nla_put_u32(msg, MPTCP_PM_ATTR_RCV_ADD_ADDRS,
   1803			READ_ONCE(pernet->add_addr_accept_max)))
   1804		goto fail;
   1805
   1806	if (nla_put_u32(msg, MPTCP_PM_ATTR_SUBFLOWS,
   1807			READ_ONCE(pernet->subflows_max)))
   1808		goto fail;
   1809
   1810	genlmsg_end(msg, reply);
   1811	return genlmsg_reply(msg, info);
   1812
   1813fail:
   1814	GENL_SET_ERR_MSG(info, "not enough space in Netlink message");
   1815	nlmsg_free(msg);
   1816	return -EMSGSIZE;
   1817}
   1818
   1819static void mptcp_pm_nl_fullmesh(struct mptcp_sock *msk,
   1820				 struct mptcp_addr_info *addr)
   1821{
   1822	struct mptcp_rm_list list = { .nr = 0 };
   1823
   1824	list.ids[list.nr++] = addr->id;
   1825
   1826	spin_lock_bh(&msk->pm.lock);
   1827	mptcp_pm_nl_rm_subflow_received(msk, &list);
   1828	mptcp_pm_create_subflow_or_signal_addr(msk);
   1829	spin_unlock_bh(&msk->pm.lock);
   1830}
   1831
   1832static int mptcp_nl_set_flags(struct net *net,
   1833			      struct mptcp_addr_info *addr,
   1834			      u8 bkup, u8 changed)
   1835{
   1836	long s_slot = 0, s_num = 0;
   1837	struct mptcp_sock *msk;
   1838	int ret = -EINVAL;
   1839
   1840	while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) {
   1841		struct sock *sk = (struct sock *)msk;
   1842
   1843		if (list_empty(&msk->conn_list) || mptcp_pm_is_userspace(msk))
   1844			goto next;
   1845
   1846		lock_sock(sk);
   1847		if (changed & MPTCP_PM_ADDR_FLAG_BACKUP)
   1848			ret = mptcp_pm_nl_mp_prio_send_ack(msk, addr, NULL, bkup);
   1849		if (changed & MPTCP_PM_ADDR_FLAG_FULLMESH)
   1850			mptcp_pm_nl_fullmesh(msk, addr);
   1851		release_sock(sk);
   1852
   1853next:
   1854		sock_put(sk);
   1855		cond_resched();
   1856	}
   1857
   1858	return ret;
   1859}
   1860
/* MPTCP_PM_CMD_SET_FLAGS handler: update the BACKUP and FULLMESH flags of an
 * endpoint, identified by address or, when no family is given, by its
 * (non-zero) id.
 *
 * When MPTCP_PM_ATTR_TOKEN is present the request is delegated to
 * mptcp_userspace_pm_set_flags(). Otherwise the in-kernel endpoint is updated
 * under the pernet lock and the change is propagated to existing sockets via
 * mptcp_nl_set_flags(). FULLMESH is refused on SIGNAL endpoints.
 *
 * Return: 0 on success, -EOPNOTSUPP for a lookup by id 0, -EINVAL if no
 * endpoint matches or FULLMESH is requested on a SIGNAL endpoint, the
 * userspace PM's result, or a parse error.
 */
static int mptcp_nl_cmd_set_flags(struct sk_buff *skb, struct genl_info *info)
{
	struct mptcp_pm_addr_entry addr = { .addr = { .family = AF_UNSPEC }, }, *entry;
	struct mptcp_pm_addr_entry remote = { .addr = { .family = AF_UNSPEC }, };
	struct nlattr *attr_rem = info->attrs[MPTCP_PM_ATTR_ADDR_REMOTE];
	struct nlattr *token = info->attrs[MPTCP_PM_ATTR_TOKEN];
	struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
	struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
	u8 changed, mask = MPTCP_PM_ADDR_FLAG_BACKUP |
			   MPTCP_PM_ADDR_FLAG_FULLMESH;
	struct net *net = sock_net(skb->sk);
	u8 bkup = 0, lookup_by_id = 0;
	int ret;

	ret = mptcp_pm_parse_entry(attr, info, false, &addr);
	if (ret < 0)
		return ret;

	if (attr_rem) {
		ret = mptcp_pm_parse_entry(attr_rem, info, false, &remote);
		if (ret < 0)
			return ret;
	}

	if (addr.flags & MPTCP_PM_ADDR_FLAG_BACKUP)
		bkup = 1;
	/* no family: the endpoint must be looked up by id, and id 0 is not
	 * a stored endpoint
	 */
	if (addr.addr.family == AF_UNSPEC) {
		lookup_by_id = 1;
		if (!addr.addr.id)
			return -EOPNOTSUPP;
	}

	if (token)
		return mptcp_userspace_pm_set_flags(sock_net(skb->sk),
						    token, &addr, &remote, bkup);

	spin_lock_bh(&pernet->lock);
	entry = __lookup_addr(pernet, &addr.addr, lookup_by_id);
	if (!entry) {
		spin_unlock_bh(&pernet->lock);
		return -EINVAL;
	}
	if ((addr.flags & MPTCP_PM_ADDR_FLAG_FULLMESH) &&
	    (entry->flags & MPTCP_PM_ADDR_FLAG_SIGNAL)) {
		spin_unlock_bh(&pernet->lock);
		return -EINVAL;
	}

	/* only the BACKUP and FULLMESH bits can be changed here */
	changed = (addr.flags ^ entry->flags) & mask;
	entry->flags = (entry->flags & ~mask) | (addr.flags & mask);
	/* snapshot the updated entry so it can be used after unlocking */
	addr = *entry;
	spin_unlock_bh(&pernet->lock);

	mptcp_nl_set_flags(net, &addr.addr, bkup, changed);
	return 0;
}
   1917
   1918static void mptcp_nl_mcast_send(struct net *net, struct sk_buff *nlskb, gfp_t gfp)
   1919{
   1920	genlmsg_multicast_netns(&mptcp_genl_family, net,
   1921				nlskb, 0, MPTCP_PM_EV_GRP_OFFSET, gfp);
   1922}
   1923
   1924bool mptcp_userspace_pm_active(const struct mptcp_sock *msk)
   1925{
   1926	return genl_has_listeners(&mptcp_genl_family,
   1927				  sock_net((const struct sock *)msk),
   1928				  MPTCP_PM_EV_GRP_OFFSET);
   1929}
   1930
   1931static int mptcp_event_add_subflow(struct sk_buff *skb, const struct sock *ssk)
   1932{
   1933	const struct inet_sock *issk = inet_sk(ssk);
   1934	const struct mptcp_subflow_context *sf;
   1935
   1936	if (nla_put_u16(skb, MPTCP_ATTR_FAMILY, ssk->sk_family))
   1937		return -EMSGSIZE;
   1938
   1939	switch (ssk->sk_family) {
   1940	case AF_INET:
   1941		if (nla_put_in_addr(skb, MPTCP_ATTR_SADDR4, issk->inet_saddr))
   1942			return -EMSGSIZE;
   1943		if (nla_put_in_addr(skb, MPTCP_ATTR_DADDR4, issk->inet_daddr))
   1944			return -EMSGSIZE;
   1945		break;
   1946#if IS_ENABLED(CONFIG_MPTCP_IPV6)
   1947	case AF_INET6: {
   1948		const struct ipv6_pinfo *np = inet6_sk(ssk);
   1949
   1950		if (nla_put_in6_addr(skb, MPTCP_ATTR_SADDR6, &np->saddr))
   1951			return -EMSGSIZE;
   1952		if (nla_put_in6_addr(skb, MPTCP_ATTR_DADDR6, &ssk->sk_v6_daddr))
   1953			return -EMSGSIZE;
   1954		break;
   1955	}
   1956#endif
   1957	default:
   1958		WARN_ON_ONCE(1);
   1959		return -EMSGSIZE;
   1960	}
   1961
   1962	if (nla_put_be16(skb, MPTCP_ATTR_SPORT, issk->inet_sport))
   1963		return -EMSGSIZE;
   1964	if (nla_put_be16(skb, MPTCP_ATTR_DPORT, issk->inet_dport))
   1965		return -EMSGSIZE;
   1966
   1967	sf = mptcp_subflow_ctx(ssk);
   1968	if (WARN_ON_ONCE(!sf))
   1969		return -EINVAL;
   1970
   1971	if (nla_put_u8(skb, MPTCP_ATTR_LOC_ID, sf->local_id))
   1972		return -EMSGSIZE;
   1973
   1974	if (nla_put_u8(skb, MPTCP_ATTR_REM_ID, sf->remote_id))
   1975		return -EMSGSIZE;
   1976
   1977	return 0;
   1978}
   1979
   1980static int mptcp_event_put_token_and_ssk(struct sk_buff *skb,
   1981					 const struct mptcp_sock *msk,
   1982					 const struct sock *ssk)
   1983{
   1984	const struct sock *sk = (const struct sock *)msk;
   1985	const struct mptcp_subflow_context *sf;
   1986	u8 sk_err;
   1987
   1988	if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token))
   1989		return -EMSGSIZE;
   1990
   1991	if (mptcp_event_add_subflow(skb, ssk))
   1992		return -EMSGSIZE;
   1993
   1994	sf = mptcp_subflow_ctx(ssk);
   1995	if (WARN_ON_ONCE(!sf))
   1996		return -EINVAL;
   1997
   1998	if (nla_put_u8(skb, MPTCP_ATTR_BACKUP, sf->backup))
   1999		return -EMSGSIZE;
   2000
   2001	if (ssk->sk_bound_dev_if &&
   2002	    nla_put_s32(skb, MPTCP_ATTR_IF_IDX, ssk->sk_bound_dev_if))
   2003		return -EMSGSIZE;
   2004
   2005	sk_err = ssk->sk_err;
   2006	if (sk_err && sk->sk_state == TCP_ESTABLISHED &&
   2007	    nla_put_u8(skb, MPTCP_ATTR_ERROR, sk_err))
   2008		return -EMSGSIZE;
   2009
   2010	return 0;
   2011}
   2012
   2013static int mptcp_event_sub_established(struct sk_buff *skb,
   2014				       const struct mptcp_sock *msk,
   2015				       const struct sock *ssk)
   2016{
   2017	return mptcp_event_put_token_and_ssk(skb, msk, ssk);
   2018}
   2019
   2020static int mptcp_event_sub_closed(struct sk_buff *skb,
   2021				  const struct mptcp_sock *msk,
   2022				  const struct sock *ssk)
   2023{
   2024	const struct mptcp_subflow_context *sf;
   2025
   2026	if (mptcp_event_put_token_and_ssk(skb, msk, ssk))
   2027		return -EMSGSIZE;
   2028
   2029	sf = mptcp_subflow_ctx(ssk);
   2030	if (!sf->reset_seen)
   2031		return 0;
   2032
   2033	if (nla_put_u32(skb, MPTCP_ATTR_RESET_REASON, sf->reset_reason))
   2034		return -EMSGSIZE;
   2035
   2036	if (nla_put_u32(skb, MPTCP_ATTR_RESET_FLAGS, sf->reset_transient))
   2037		return -EMSGSIZE;
   2038
   2039	return 0;
   2040}
   2041
   2042static int mptcp_event_created(struct sk_buff *skb,
   2043			       const struct mptcp_sock *msk,
   2044			       const struct sock *ssk)
   2045{
   2046	int err = nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token);
   2047
   2048	if (err)
   2049		return err;
   2050
   2051	if (nla_put_u8(skb, MPTCP_ATTR_SERVER_SIDE, READ_ONCE(msk->pm.server_side)))
   2052		return -EMSGSIZE;
   2053
   2054	return mptcp_event_add_subflow(skb, ssk);
   2055}
   2056
   2057void mptcp_event_addr_removed(const struct mptcp_sock *msk, uint8_t id)
   2058{
   2059	struct net *net = sock_net((const struct sock *)msk);
   2060	struct nlmsghdr *nlh;
   2061	struct sk_buff *skb;
   2062
   2063	if (!genl_has_listeners(&mptcp_genl_family, net, MPTCP_PM_EV_GRP_OFFSET))
   2064		return;
   2065
   2066	skb = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_ATOMIC);
   2067	if (!skb)
   2068		return;
   2069
   2070	nlh = genlmsg_put(skb, 0, 0, &mptcp_genl_family, 0, MPTCP_EVENT_REMOVED);
   2071	if (!nlh)
   2072		goto nla_put_failure;
   2073
   2074	if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token))
   2075		goto nla_put_failure;
   2076
   2077	if (nla_put_u8(skb, MPTCP_ATTR_REM_ID, id))
   2078		goto nla_put_failure;
   2079
   2080	genlmsg_end(skb, nlh);
   2081	mptcp_nl_mcast_send(net, skb, GFP_ATOMIC);
   2082	return;
   2083
   2084nla_put_failure:
   2085	kfree_skb(skb);
   2086}
   2087
   2088void mptcp_event_addr_announced(const struct sock *ssk,
   2089				const struct mptcp_addr_info *info)
   2090{
   2091	struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(ssk);
   2092	struct mptcp_sock *msk = mptcp_sk(subflow->conn);
   2093	struct net *net = sock_net(ssk);
   2094	struct nlmsghdr *nlh;
   2095	struct sk_buff *skb;
   2096
   2097	if (!genl_has_listeners(&mptcp_genl_family, net, MPTCP_PM_EV_GRP_OFFSET))
   2098		return;
   2099
   2100	skb = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_ATOMIC);
   2101	if (!skb)
   2102		return;
   2103
   2104	nlh = genlmsg_put(skb, 0, 0, &mptcp_genl_family, 0,
   2105			  MPTCP_EVENT_ANNOUNCED);
   2106	if (!nlh)
   2107		goto nla_put_failure;
   2108
   2109	if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token))
   2110		goto nla_put_failure;
   2111
   2112	if (nla_put_u8(skb, MPTCP_ATTR_REM_ID, info->id))
   2113		goto nla_put_failure;
   2114
   2115	if (nla_put_be16(skb, MPTCP_ATTR_DPORT,
   2116			 info->port == 0 ?
   2117			 inet_sk(ssk)->inet_dport :
   2118			 info->port))
   2119		goto nla_put_failure;
   2120
   2121	switch (info->family) {
   2122	case AF_INET:
   2123		if (nla_put_in_addr(skb, MPTCP_ATTR_DADDR4, info->addr.s_addr))
   2124			goto nla_put_failure;
   2125		break;
   2126#if IS_ENABLED(CONFIG_MPTCP_IPV6)
   2127	case AF_INET6:
   2128		if (nla_put_in6_addr(skb, MPTCP_ATTR_DADDR6, &info->addr6))
   2129			goto nla_put_failure;
   2130		break;
   2131#endif
   2132	default:
   2133		WARN_ON_ONCE(1);
   2134		goto nla_put_failure;
   2135	}
   2136
   2137	genlmsg_end(skb, nlh);
   2138	mptcp_nl_mcast_send(net, skb, GFP_ATOMIC);
   2139	return;
   2140
   2141nla_put_failure:
   2142	kfree_skb(skb);
   2143}
   2144
   2145void mptcp_event(enum mptcp_event_type type, const struct mptcp_sock *msk,
   2146		 const struct sock *ssk, gfp_t gfp)
   2147{
   2148	struct net *net = sock_net((const struct sock *)msk);
   2149	struct nlmsghdr *nlh;
   2150	struct sk_buff *skb;
   2151
   2152	if (!genl_has_listeners(&mptcp_genl_family, net, MPTCP_PM_EV_GRP_OFFSET))
   2153		return;
   2154
   2155	skb = nlmsg_new(NLMSG_DEFAULT_SIZE, gfp);
   2156	if (!skb)
   2157		return;
   2158
   2159	nlh = genlmsg_put(skb, 0, 0, &mptcp_genl_family, 0, type);
   2160	if (!nlh)
   2161		goto nla_put_failure;
   2162
   2163	switch (type) {
   2164	case MPTCP_EVENT_UNSPEC:
   2165		WARN_ON_ONCE(1);
   2166		break;
   2167	case MPTCP_EVENT_CREATED:
   2168	case MPTCP_EVENT_ESTABLISHED:
   2169		if (mptcp_event_created(skb, msk, ssk) < 0)
   2170			goto nla_put_failure;
   2171		break;
   2172	case MPTCP_EVENT_CLOSED:
   2173		if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token) < 0)
   2174			goto nla_put_failure;
   2175		break;
   2176	case MPTCP_EVENT_ANNOUNCED:
   2177	case MPTCP_EVENT_REMOVED:
   2178		/* call mptcp_event_addr_announced()/removed instead */
   2179		WARN_ON_ONCE(1);
   2180		break;
   2181	case MPTCP_EVENT_SUB_ESTABLISHED:
   2182	case MPTCP_EVENT_SUB_PRIORITY:
   2183		if (mptcp_event_sub_established(skb, msk, ssk) < 0)
   2184			goto nla_put_failure;
   2185		break;
   2186	case MPTCP_EVENT_SUB_CLOSED:
   2187		if (mptcp_event_sub_closed(skb, msk, ssk) < 0)
   2188			goto nla_put_failure;
   2189		break;
   2190	}
   2191
   2192	genlmsg_end(skb, nlh);
   2193	mptcp_nl_mcast_send(net, skb, gfp);
   2194	return;
   2195
   2196nla_put_failure:
   2197	kfree_skb(skb);
   2198}
   2199
   2200static const struct genl_small_ops mptcp_pm_ops[] = {
   2201	{
   2202		.cmd    = MPTCP_PM_CMD_ADD_ADDR,
   2203		.doit   = mptcp_nl_cmd_add_addr,
   2204		.flags  = GENL_ADMIN_PERM,
   2205	},
   2206	{
   2207		.cmd    = MPTCP_PM_CMD_DEL_ADDR,
   2208		.doit   = mptcp_nl_cmd_del_addr,
   2209		.flags  = GENL_ADMIN_PERM,
   2210	},
   2211	{
   2212		.cmd    = MPTCP_PM_CMD_FLUSH_ADDRS,
   2213		.doit   = mptcp_nl_cmd_flush_addrs,
   2214		.flags  = GENL_ADMIN_PERM,
   2215	},
   2216	{
   2217		.cmd    = MPTCP_PM_CMD_GET_ADDR,
   2218		.doit   = mptcp_nl_cmd_get_addr,
   2219		.dumpit   = mptcp_nl_cmd_dump_addrs,
   2220	},
   2221	{
   2222		.cmd    = MPTCP_PM_CMD_SET_LIMITS,
   2223		.doit   = mptcp_nl_cmd_set_limits,
   2224		.flags  = GENL_ADMIN_PERM,
   2225	},
   2226	{
   2227		.cmd    = MPTCP_PM_CMD_GET_LIMITS,
   2228		.doit   = mptcp_nl_cmd_get_limits,
   2229	},
   2230	{
   2231		.cmd    = MPTCP_PM_CMD_SET_FLAGS,
   2232		.doit   = mptcp_nl_cmd_set_flags,
   2233		.flags  = GENL_ADMIN_PERM,
   2234	},
   2235	{
   2236		.cmd    = MPTCP_PM_CMD_ANNOUNCE,
   2237		.doit   = mptcp_nl_cmd_announce,
   2238		.flags  = GENL_ADMIN_PERM,
   2239	},
   2240	{
   2241		.cmd    = MPTCP_PM_CMD_REMOVE,
   2242		.doit   = mptcp_nl_cmd_remove,
   2243		.flags  = GENL_ADMIN_PERM,
   2244	},
   2245	{
   2246		.cmd    = MPTCP_PM_CMD_SUBFLOW_CREATE,
   2247		.doit   = mptcp_nl_cmd_sf_create,
   2248		.flags  = GENL_ADMIN_PERM,
   2249	},
   2250	{
   2251		.cmd    = MPTCP_PM_CMD_SUBFLOW_DESTROY,
   2252		.doit   = mptcp_nl_cmd_sf_destroy,
   2253		.flags  = GENL_ADMIN_PERM,
   2254	},
   2255};
   2256
   2257static struct genl_family mptcp_genl_family __ro_after_init = {
   2258	.name		= MPTCP_PM_NAME,
   2259	.version	= MPTCP_PM_VER,
   2260	.maxattr	= MPTCP_PM_ATTR_MAX,
   2261	.policy		= mptcp_pm_policy,
   2262	.netnsok	= true,
   2263	.module		= THIS_MODULE,
   2264	.small_ops	= mptcp_pm_ops,
   2265	.n_small_ops	= ARRAY_SIZE(mptcp_pm_ops),
   2266	.mcgrps		= mptcp_pm_mcgrps,
   2267	.n_mcgrps	= ARRAY_SIZE(mptcp_pm_mcgrps),
   2268};
   2269
   2270static int __net_init pm_nl_init_net(struct net *net)
   2271{
   2272	struct pm_nl_pernet *pernet = pm_nl_get_pernet(net);
   2273
   2274	INIT_LIST_HEAD_RCU(&pernet->local_addr_list);
   2275
   2276	/* Cit. 2 subflows ought to be enough for anybody. */
   2277	pernet->subflows_max = 2;
   2278	pernet->next_id = 1;
   2279	pernet->stale_loss_cnt = 4;
   2280	spin_lock_init(&pernet->lock);
   2281
   2282	/* No need to initialize other pernet fields, the struct is zeroed at
   2283	 * allocation time.
   2284	 */
   2285
   2286	return 0;
   2287}
   2288
   2289static void __net_exit pm_nl_exit_net(struct list_head *net_list)
   2290{
   2291	struct net *net;
   2292
   2293	list_for_each_entry(net, net_list, exit_list) {
   2294		struct pm_nl_pernet *pernet = pm_nl_get_pernet(net);
   2295
   2296		/* net is removed from namespace list, can't race with
   2297		 * other modifiers, also netns core already waited for a
   2298		 * RCU grace period.
   2299		 */
   2300		__flush_addrs(&pernet->local_addr_list);
   2301	}
   2302}
   2303
   2304static struct pernet_operations mptcp_pm_pernet_ops = {
   2305	.init = pm_nl_init_net,
   2306	.exit_batch = pm_nl_exit_net,
   2307	.id = &pm_nl_pernet_id,
   2308	.size = sizeof(struct pm_nl_pernet),
   2309};
   2310
   2311void __init mptcp_pm_nl_init(void)
   2312{
   2313	if (register_pernet_subsys(&mptcp_pm_pernet_ops) < 0)
   2314		panic("Failed to register MPTCP PM pernet subsystem.\n");
   2315
   2316	if (genl_register_family(&mptcp_genl_family))
   2317		panic("Failed to register MPTCP PM netlink family\n");
   2318}