cachepc-linux

Fork of AMDESE/linux with modifications for CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux
Log | Files | Refs | README | LICENSE | sfeed.txt

pm_userspace.c (10974B)


      1// SPDX-License-Identifier: GPL-2.0
      2/* Multipath TCP
      3 *
      4 * Copyright (c) 2022, Intel Corporation.
      5 */
      6
      7#include "protocol.h"
      8#include "mib.h"
      9
     10void mptcp_free_local_addr_list(struct mptcp_sock *msk)
     11{
     12	struct mptcp_pm_addr_entry *entry, *tmp;
     13	struct sock *sk = (struct sock *)msk;
     14	LIST_HEAD(free_list);
     15
     16	if (!mptcp_pm_is_userspace(msk))
     17		return;
     18
     19	spin_lock_bh(&msk->pm.lock);
     20	list_splice_init(&msk->pm.userspace_pm_local_addr_list, &free_list);
     21	spin_unlock_bh(&msk->pm.lock);
     22
     23	list_for_each_entry_safe(entry, tmp, &free_list, list) {
     24		sock_kfree_s(sk, entry, sizeof(*entry));
     25	}
     26}
     27
     28int mptcp_userspace_pm_append_new_local_addr(struct mptcp_sock *msk,
     29					     struct mptcp_pm_addr_entry *entry)
     30{
     31	DECLARE_BITMAP(id_bitmap, MPTCP_PM_MAX_ADDR_ID + 1);
     32	struct mptcp_pm_addr_entry *match = NULL;
     33	struct sock *sk = (struct sock *)msk;
     34	struct mptcp_pm_addr_entry *e;
     35	bool addr_match = false;
     36	bool id_match = false;
     37	int ret = -EINVAL;
     38
     39	bitmap_zero(id_bitmap, MPTCP_PM_MAX_ADDR_ID + 1);
     40
     41	spin_lock_bh(&msk->pm.lock);
     42	list_for_each_entry(e, &msk->pm.userspace_pm_local_addr_list, list) {
     43		addr_match = mptcp_addresses_equal(&e->addr, &entry->addr, true);
     44		if (addr_match && entry->addr.id == 0)
     45			entry->addr.id = e->addr.id;
     46		id_match = (e->addr.id == entry->addr.id);
     47		if (addr_match && id_match) {
     48			match = e;
     49			break;
     50		} else if (addr_match || id_match) {
     51			break;
     52		}
     53		__set_bit(e->addr.id, id_bitmap);
     54	}
     55
     56	if (!match && !addr_match && !id_match) {
     57		/* Memory for the entry is allocated from the
     58		 * sock option buffer.
     59		 */
     60		e = sock_kmalloc(sk, sizeof(*e), GFP_ATOMIC);
     61		if (!e) {
     62			spin_unlock_bh(&msk->pm.lock);
     63			return -ENOMEM;
     64		}
     65
     66		*e = *entry;
     67		if (!e->addr.id)
     68			e->addr.id = find_next_zero_bit(id_bitmap,
     69							MPTCP_PM_MAX_ADDR_ID + 1,
     70							1);
     71		list_add_tail_rcu(&e->list, &msk->pm.userspace_pm_local_addr_list);
     72		ret = e->addr.id;
     73	} else if (match) {
     74		ret = entry->addr.id;
     75	}
     76
     77	spin_unlock_bh(&msk->pm.lock);
     78	return ret;
     79}
     80
     81int mptcp_userspace_pm_get_flags_and_ifindex_by_id(struct mptcp_sock *msk,
     82						   unsigned int id,
     83						   u8 *flags, int *ifindex)
     84{
     85	struct mptcp_pm_addr_entry *entry, *match = NULL;
     86
     87	*flags = 0;
     88	*ifindex = 0;
     89
     90	spin_lock_bh(&msk->pm.lock);
     91	list_for_each_entry(entry, &msk->pm.userspace_pm_local_addr_list, list) {
     92		if (id == entry->addr.id) {
     93			match = entry;
     94			break;
     95		}
     96	}
     97	spin_unlock_bh(&msk->pm.lock);
     98	if (match) {
     99		*flags = match->flags;
    100		*ifindex = match->ifindex;
    101	}
    102
    103	return 0;
    104}
    105
    106int mptcp_userspace_pm_get_local_id(struct mptcp_sock *msk,
    107				    struct mptcp_addr_info *skc)
    108{
    109	struct mptcp_pm_addr_entry new_entry;
    110	__be16 msk_sport =  ((struct inet_sock *)
    111			     inet_sk((struct sock *)msk))->inet_sport;
    112
    113	memset(&new_entry, 0, sizeof(struct mptcp_pm_addr_entry));
    114	new_entry.addr = *skc;
    115	new_entry.addr.id = 0;
    116	new_entry.flags = MPTCP_PM_ADDR_FLAG_IMPLICIT;
    117
    118	if (new_entry.addr.port == msk_sport)
    119		new_entry.addr.port = 0;
    120
    121	return mptcp_userspace_pm_append_new_local_addr(msk, &new_entry);
    122}
    123
    124int mptcp_nl_cmd_announce(struct sk_buff *skb, struct genl_info *info)
    125{
    126	struct nlattr *token = info->attrs[MPTCP_PM_ATTR_TOKEN];
    127	struct nlattr *addr = info->attrs[MPTCP_PM_ATTR_ADDR];
    128	struct mptcp_pm_addr_entry addr_val;
    129	struct mptcp_sock *msk;
    130	int err = -EINVAL;
    131	u32 token_val;
    132
    133	if (!addr || !token) {
    134		GENL_SET_ERR_MSG(info, "missing required inputs");
    135		return err;
    136	}
    137
    138	token_val = nla_get_u32(token);
    139
    140	msk = mptcp_token_get_sock(sock_net(skb->sk), token_val);
    141	if (!msk) {
    142		NL_SET_ERR_MSG_ATTR(info->extack, token, "invalid token");
    143		return err;
    144	}
    145
    146	if (!mptcp_pm_is_userspace(msk)) {
    147		GENL_SET_ERR_MSG(info, "invalid request; userspace PM not selected");
    148		goto announce_err;
    149	}
    150
    151	err = mptcp_pm_parse_entry(addr, info, true, &addr_val);
    152	if (err < 0) {
    153		GENL_SET_ERR_MSG(info, "error parsing local address");
    154		goto announce_err;
    155	}
    156
    157	if (addr_val.addr.id == 0 || !(addr_val.flags & MPTCP_PM_ADDR_FLAG_SIGNAL)) {
    158		GENL_SET_ERR_MSG(info, "invalid addr id or flags");
    159		goto announce_err;
    160	}
    161
    162	err = mptcp_userspace_pm_append_new_local_addr(msk, &addr_val);
    163	if (err < 0) {
    164		GENL_SET_ERR_MSG(info, "did not match address and id");
    165		goto announce_err;
    166	}
    167
    168	lock_sock((struct sock *)msk);
    169	spin_lock_bh(&msk->pm.lock);
    170
    171	if (mptcp_pm_alloc_anno_list(msk, &addr_val)) {
    172		mptcp_pm_announce_addr(msk, &addr_val.addr, false);
    173		mptcp_pm_nl_addr_send_ack(msk);
    174	}
    175
    176	spin_unlock_bh(&msk->pm.lock);
    177	release_sock((struct sock *)msk);
    178
    179	err = 0;
    180 announce_err:
    181	sock_put((struct sock *)msk);
    182	return err;
    183}
    184
    185int mptcp_nl_cmd_remove(struct sk_buff *skb, struct genl_info *info)
    186{
    187	struct nlattr *token = info->attrs[MPTCP_PM_ATTR_TOKEN];
    188	struct nlattr *id = info->attrs[MPTCP_PM_ATTR_LOC_ID];
    189	struct mptcp_pm_addr_entry *match = NULL;
    190	struct mptcp_pm_addr_entry *entry;
    191	struct mptcp_sock *msk;
    192	LIST_HEAD(free_list);
    193	int err = -EINVAL;
    194	u32 token_val;
    195	u8 id_val;
    196
    197	if (!id || !token) {
    198		GENL_SET_ERR_MSG(info, "missing required inputs");
    199		return err;
    200	}
    201
    202	id_val = nla_get_u8(id);
    203	token_val = nla_get_u32(token);
    204
    205	msk = mptcp_token_get_sock(sock_net(skb->sk), token_val);
    206	if (!msk) {
    207		NL_SET_ERR_MSG_ATTR(info->extack, token, "invalid token");
    208		return err;
    209	}
    210
    211	if (!mptcp_pm_is_userspace(msk)) {
    212		GENL_SET_ERR_MSG(info, "invalid request; userspace PM not selected");
    213		goto remove_err;
    214	}
    215
    216	lock_sock((struct sock *)msk);
    217
    218	list_for_each_entry(entry, &msk->pm.userspace_pm_local_addr_list, list) {
    219		if (entry->addr.id == id_val) {
    220			match = entry;
    221			break;
    222		}
    223	}
    224
    225	if (!match) {
    226		GENL_SET_ERR_MSG(info, "address with specified id not found");
    227		release_sock((struct sock *)msk);
    228		goto remove_err;
    229	}
    230
    231	list_move(&match->list, &free_list);
    232
    233	mptcp_pm_remove_addrs_and_subflows(msk, &free_list);
    234
    235	release_sock((struct sock *)msk);
    236
    237	list_for_each_entry_safe(match, entry, &free_list, list) {
    238		sock_kfree_s((struct sock *)msk, match, sizeof(*match));
    239	}
    240
    241	err = 0;
    242 remove_err:
    243	sock_put((struct sock *)msk);
    244	return err;
    245}
    246
    247int mptcp_nl_cmd_sf_create(struct sk_buff *skb, struct genl_info *info)
    248{
    249	struct nlattr *raddr = info->attrs[MPTCP_PM_ATTR_ADDR_REMOTE];
    250	struct nlattr *token = info->attrs[MPTCP_PM_ATTR_TOKEN];
    251	struct nlattr *laddr = info->attrs[MPTCP_PM_ATTR_ADDR];
    252	struct mptcp_addr_info addr_r;
    253	struct mptcp_addr_info addr_l;
    254	struct mptcp_sock *msk;
    255	int err = -EINVAL;
    256	struct sock *sk;
    257	u32 token_val;
    258
    259	if (!laddr || !raddr || !token) {
    260		GENL_SET_ERR_MSG(info, "missing required inputs");
    261		return err;
    262	}
    263
    264	token_val = nla_get_u32(token);
    265
    266	msk = mptcp_token_get_sock(genl_info_net(info), token_val);
    267	if (!msk) {
    268		NL_SET_ERR_MSG_ATTR(info->extack, token, "invalid token");
    269		return err;
    270	}
    271
    272	if (!mptcp_pm_is_userspace(msk)) {
    273		GENL_SET_ERR_MSG(info, "invalid request; userspace PM not selected");
    274		goto create_err;
    275	}
    276
    277	err = mptcp_pm_parse_addr(laddr, info, &addr_l);
    278	if (err < 0) {
    279		NL_SET_ERR_MSG_ATTR(info->extack, laddr, "error parsing local addr");
    280		goto create_err;
    281	}
    282
    283	if (addr_l.id == 0) {
    284		NL_SET_ERR_MSG_ATTR(info->extack, laddr, "missing local addr id");
    285		goto create_err;
    286	}
    287
    288	err = mptcp_pm_parse_addr(raddr, info, &addr_r);
    289	if (err < 0) {
    290		NL_SET_ERR_MSG_ATTR(info->extack, raddr, "error parsing remote addr");
    291		goto create_err;
    292	}
    293
    294	sk = &msk->sk.icsk_inet.sk;
    295	lock_sock(sk);
    296
    297	err = __mptcp_subflow_connect(sk, &addr_l, &addr_r);
    298
    299	release_sock(sk);
    300
    301 create_err:
    302	sock_put((struct sock *)msk);
    303	return err;
    304}
    305
    306static struct sock *mptcp_nl_find_ssk(struct mptcp_sock *msk,
    307				      const struct mptcp_addr_info *local,
    308				      const struct mptcp_addr_info *remote)
    309{
    310	struct mptcp_subflow_context *subflow;
    311
    312	if (local->family != remote->family)
    313		return NULL;
    314
    315	mptcp_for_each_subflow(msk, subflow) {
    316		const struct inet_sock *issk;
    317		struct sock *ssk;
    318
    319		ssk = mptcp_subflow_tcp_sock(subflow);
    320
    321		if (local->family != ssk->sk_family)
    322			continue;
    323
    324		issk = inet_sk(ssk);
    325
    326		switch (ssk->sk_family) {
    327		case AF_INET:
    328			if (issk->inet_saddr != local->addr.s_addr ||
    329			    issk->inet_daddr != remote->addr.s_addr)
    330				continue;
    331			break;
    332#if IS_ENABLED(CONFIG_MPTCP_IPV6)
    333		case AF_INET6: {
    334			const struct ipv6_pinfo *pinfo = inet6_sk(ssk);
    335
    336			if (!ipv6_addr_equal(&local->addr6, &pinfo->saddr) ||
    337			    !ipv6_addr_equal(&remote->addr6, &ssk->sk_v6_daddr))
    338				continue;
    339			break;
    340		}
    341#endif
    342		default:
    343			continue;
    344		}
    345
    346		if (issk->inet_sport == local->port &&
    347		    issk->inet_dport == remote->port)
    348			return ssk;
    349	}
    350
    351	return NULL;
    352}
    353
    354int mptcp_nl_cmd_sf_destroy(struct sk_buff *skb, struct genl_info *info)
    355{
    356	struct nlattr *raddr = info->attrs[MPTCP_PM_ATTR_ADDR_REMOTE];
    357	struct nlattr *token = info->attrs[MPTCP_PM_ATTR_TOKEN];
    358	struct nlattr *laddr = info->attrs[MPTCP_PM_ATTR_ADDR];
    359	struct mptcp_addr_info addr_l;
    360	struct mptcp_addr_info addr_r;
    361	struct mptcp_sock *msk;
    362	struct sock *sk, *ssk;
    363	int err = -EINVAL;
    364	u32 token_val;
    365
    366	if (!laddr || !raddr || !token) {
    367		GENL_SET_ERR_MSG(info, "missing required inputs");
    368		return err;
    369	}
    370
    371	token_val = nla_get_u32(token);
    372
    373	msk = mptcp_token_get_sock(genl_info_net(info), token_val);
    374	if (!msk) {
    375		NL_SET_ERR_MSG_ATTR(info->extack, token, "invalid token");
    376		return err;
    377	}
    378
    379	if (!mptcp_pm_is_userspace(msk)) {
    380		GENL_SET_ERR_MSG(info, "invalid request; userspace PM not selected");
    381		goto destroy_err;
    382	}
    383
    384	err = mptcp_pm_parse_addr(laddr, info, &addr_l);
    385	if (err < 0) {
    386		NL_SET_ERR_MSG_ATTR(info->extack, laddr, "error parsing local addr");
    387		goto destroy_err;
    388	}
    389
    390	err = mptcp_pm_parse_addr(raddr, info, &addr_r);
    391	if (err < 0) {
    392		NL_SET_ERR_MSG_ATTR(info->extack, raddr, "error parsing remote addr");
    393		goto destroy_err;
    394	}
    395
    396	if (addr_l.family != addr_r.family) {
    397		GENL_SET_ERR_MSG(info, "address families do not match");
    398		goto destroy_err;
    399	}
    400
    401	if (!addr_l.port || !addr_r.port) {
    402		GENL_SET_ERR_MSG(info, "missing local or remote port");
    403		goto destroy_err;
    404	}
    405
    406	sk = &msk->sk.icsk_inet.sk;
    407	lock_sock(sk);
    408	ssk = mptcp_nl_find_ssk(msk, &addr_l, &addr_r);
    409	if (ssk) {
    410		struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(ssk);
    411
    412		mptcp_subflow_shutdown(sk, ssk, RCV_SHUTDOWN | SEND_SHUTDOWN);
    413		mptcp_close_ssk(sk, ssk, subflow);
    414		MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_RMSUBFLOW);
    415		err = 0;
    416	} else {
    417		err = -ESRCH;
    418	}
    419	release_sock(sk);
    420
    421destroy_err:
    422	sock_put((struct sock *)msk);
    423	return err;
    424}
    425
    426int mptcp_userspace_pm_set_flags(struct net *net, struct nlattr *token,
    427				 struct mptcp_pm_addr_entry *loc,
    428				 struct mptcp_pm_addr_entry *rem, u8 bkup)
    429{
    430	struct mptcp_sock *msk;
    431	int ret = -EINVAL;
    432	u32 token_val;
    433
    434	token_val = nla_get_u32(token);
    435
    436	msk = mptcp_token_get_sock(net, token_val);
    437	if (!msk)
    438		return ret;
    439
    440	if (!mptcp_pm_is_userspace(msk))
    441		goto set_flags_err;
    442
    443	if (loc->addr.family == AF_UNSPEC ||
    444	    rem->addr.family == AF_UNSPEC)
    445		goto set_flags_err;
    446
    447	lock_sock((struct sock *)msk);
    448	ret = mptcp_pm_nl_mp_prio_send_ack(msk, &loc->addr, &rem->addr, bkup);
    449	release_sock((struct sock *)msk);
    450
    451set_flags_err:
    452	sock_put((struct sock *)msk);
    453	return ret;
    454}