cachepc-linux

Fork of AMDESE/linux with modifications for the CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux

sock_map.c (39759B)
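
This file (net/core/sock_map.c in the kernel tree) implements the BPF_MAP_TYPE_SOCKMAP and BPF_MAP_TYPE_SOCKHASH map types: maps from small keys to live sockets, which sk_skb/sk_msg programs can redirect traffic between. As orientation before the listing, below is a minimal userspace sketch of creating a sockmap and inserting an established TCP socket. It is not part of this file and only illustrates the checks in sock_map_alloc() and sock_map_update_elem_sys(); it assumes libbpf >= 0.7 (bpf_map_create, bpf_map_update_elem), CAP_NET_ADMIN, and a placeholder loopback address/port.

/* Sketch: create a BPF_MAP_TYPE_SOCKMAP and add one TCP socket to it.
 * Mirrors the checks in sock_map_alloc() (key_size == 4, value_size 4 or 8)
 * and sock_map_update_elem_sys() (TCP socket must be ESTABLISHED or LISTEN).
 */
#include <bpf/bpf.h>
#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <stdio.h>
#include <unistd.h>

int main(void)
{
	int map_fd = bpf_map_create(BPF_MAP_TYPE_SOCKMAP, "sockmap_demo",
				    sizeof(__u32), sizeof(__u64), 16, NULL);
	if (map_fd < 0) {
		perror("bpf_map_create");
		return 1;
	}

	int sock_fd = socket(AF_INET, SOCK_STREAM, 0);
	struct sockaddr_in addr = {
		.sin_family = AF_INET,
		.sin_port = htons(8080),		/* placeholder port */
		.sin_addr.s_addr = htonl(INADDR_LOOPBACK),
	};
	/* connect first: sock_map_sk_state_allowed() rejects TCP sockets
	 * that are neither ESTABLISHED nor LISTEN */
	if (sock_fd < 0 ||
	    connect(sock_fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
		perror("socket/connect");
		return 1;
	}

	__u32 key = 0;
	__u64 value = sock_fd;	/* the map value is the socket fd */
	if (bpf_map_update_elem(map_fd, &key, &value, BPF_ANY) < 0) {
		perror("bpf_map_update_elem");
		return 1;
	}

	close(sock_fd);
	close(map_fd);
	return 0;
}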


      1// SPDX-License-Identifier: GPL-2.0
      2/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */
      3
      4#include <linux/bpf.h>
      5#include <linux/btf_ids.h>
      6#include <linux/filter.h>
      7#include <linux/errno.h>
      8#include <linux/file.h>
      9#include <linux/net.h>
     10#include <linux/workqueue.h>
     11#include <linux/skmsg.h>
     12#include <linux/list.h>
     13#include <linux/jhash.h>
     14#include <linux/sock_diag.h>
     15#include <net/udp.h>
     16
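        /* Backing state for a BPF_MAP_TYPE_SOCKMAP: a flat array of sock
         * pointers indexed by the u32 key, the attached psock programs, and a
         * raw spinlock serializing slot updates against deletions.
         */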
     17struct bpf_stab {
     18	struct bpf_map map;
     19	struct sock **sks;
     20	struct sk_psock_progs progs;
     21	raw_spinlock_t lock;
     22};
     23
     24#define SOCK_CREATE_FLAG_MASK				\
     25	(BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)
     26
     27static int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
     28				struct bpf_prog *old, u32 which);
     29static struct sk_psock_progs *sock_map_progs(struct bpf_map *map);
     30
     31static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
     32{
     33	struct bpf_stab *stab;
     34
     35	if (!capable(CAP_NET_ADMIN))
     36		return ERR_PTR(-EPERM);
     37	if (attr->max_entries == 0 ||
     38	    attr->key_size    != 4 ||
     39	    (attr->value_size != sizeof(u32) &&
     40	     attr->value_size != sizeof(u64)) ||
     41	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
     42		return ERR_PTR(-EINVAL);
     43
     44	stab = kzalloc(sizeof(*stab), GFP_USER | __GFP_ACCOUNT);
     45	if (!stab)
     46		return ERR_PTR(-ENOMEM);
     47
     48	bpf_map_init_from_attr(&stab->map, attr);
     49	raw_spin_lock_init(&stab->lock);
     50
     51	stab->sks = bpf_map_area_alloc((u64) stab->map.max_entries *
     52				       sizeof(struct sock *),
     53				       stab->map.numa_node);
     54	if (!stab->sks) {
     55		kfree(stab);
     56		return ERR_PTR(-ENOMEM);
     57	}
     58
     59	return &stab->map;
     60}
     61
     62int sock_map_get_from_fd(const union bpf_attr *attr, struct bpf_prog *prog)
     63{
     64	u32 ufd = attr->target_fd;
     65	struct bpf_map *map;
     66	struct fd f;
     67	int ret;
     68
     69	if (attr->attach_flags || attr->replace_bpf_fd)
     70		return -EINVAL;
     71
     72	f = fdget(ufd);
     73	map = __bpf_map_get(f);
     74	if (IS_ERR(map))
     75		return PTR_ERR(map);
     76	ret = sock_map_prog_update(map, prog, NULL, attr->attach_type);
     77	fdput(f);
     78	return ret;
     79}
     80
     81int sock_map_prog_detach(const union bpf_attr *attr, enum bpf_prog_type ptype)
     82{
     83	u32 ufd = attr->target_fd;
     84	struct bpf_prog *prog;
     85	struct bpf_map *map;
     86	struct fd f;
     87	int ret;
     88
     89	if (attr->attach_flags || attr->replace_bpf_fd)
     90		return -EINVAL;
     91
     92	f = fdget(ufd);
     93	map = __bpf_map_get(f);
     94	if (IS_ERR(map))
     95		return PTR_ERR(map);
     96
     97	prog = bpf_prog_get(attr->attach_bpf_fd);
     98	if (IS_ERR(prog)) {
     99		ret = PTR_ERR(prog);
    100		goto put_map;
    101	}
    102
    103	if (prog->type != ptype) {
    104		ret = -EINVAL;
    105		goto put_prog;
    106	}
    107
    108	ret = sock_map_prog_update(map, NULL, prog, attr->attach_type);
    109put_prog:
    110	bpf_prog_put(prog);
    111put_map:
    112	fdput(f);
    113	return ret;
    114}
    115
    116static void sock_map_sk_acquire(struct sock *sk)
    117	__acquires(&sk->sk_lock.slock)
    118{
    119	lock_sock(sk);
    120	preempt_disable();
    121	rcu_read_lock();
    122}
    123
    124static void sock_map_sk_release(struct sock *sk)
    125	__releases(&sk->sk_lock.slock)
    126{
    127	rcu_read_unlock();
    128	preempt_enable();
    129	release_sock(sk);
    130}
    131
    132static void sock_map_add_link(struct sk_psock *psock,
    133			      struct sk_psock_link *link,
    134			      struct bpf_map *map, void *link_raw)
    135{
    136	link->link_raw = link_raw;
    137	link->map = map;
    138	spin_lock_bh(&psock->link_lock);
    139	list_add_tail(&link->list, &psock->link);
    140	spin_unlock_bh(&psock->link_lock);
    141}
    142
    143static void sock_map_del_link(struct sock *sk,
    144			      struct sk_psock *psock, void *link_raw)
    145{
    146	bool strp_stop = false, verdict_stop = false;
    147	struct sk_psock_link *link, *tmp;
    148
    149	spin_lock_bh(&psock->link_lock);
    150	list_for_each_entry_safe(link, tmp, &psock->link, list) {
    151		if (link->link_raw == link_raw) {
    152			struct bpf_map *map = link->map;
    153			struct bpf_stab *stab = container_of(map, struct bpf_stab,
    154							     map);
    155			if (psock->saved_data_ready && stab->progs.stream_parser)
    156				strp_stop = true;
    157			if (psock->saved_data_ready && stab->progs.stream_verdict)
    158				verdict_stop = true;
    159			if (psock->saved_data_ready && stab->progs.skb_verdict)
    160				verdict_stop = true;
    161			list_del(&link->list);
    162			sk_psock_free_link(link);
    163		}
    164	}
    165	spin_unlock_bh(&psock->link_lock);
    166	if (strp_stop || verdict_stop) {
    167		write_lock_bh(&sk->sk_callback_lock);
    168		if (strp_stop)
    169			sk_psock_stop_strp(sk, psock);
    170		if (verdict_stop)
    171			sk_psock_stop_verdict(sk, psock);
    172
    173		if (psock->psock_update_sk_prot)
    174			psock->psock_update_sk_prot(sk, psock, false);
    175		write_unlock_bh(&sk->sk_callback_lock);
    176	}
    177}
    178
    179static void sock_map_unref(struct sock *sk, void *link_raw)
    180{
    181	struct sk_psock *psock = sk_psock(sk);
    182
    183	if (likely(psock)) {
    184		sock_map_del_link(sk, psock, link_raw);
    185		sk_psock_put(sk, psock);
    186	}
    187}
    188
    189static int sock_map_init_proto(struct sock *sk, struct sk_psock *psock)
    190{
    191	if (!sk->sk_prot->psock_update_sk_prot)
    192		return -EINVAL;
    193	psock->psock_update_sk_prot = sk->sk_prot->psock_update_sk_prot;
    194	return sk->sk_prot->psock_update_sk_prot(sk, psock, false);
    195}
    196
    197static struct sk_psock *sock_map_psock_get_checked(struct sock *sk)
    198{
    199	struct sk_psock *psock;
    200
    201	rcu_read_lock();
    202	psock = sk_psock(sk);
    203	if (psock) {
    204		if (sk->sk_prot->close != sock_map_close) {
    205			psock = ERR_PTR(-EBUSY);
    206			goto out;
    207		}
    208
    209		if (!refcount_inc_not_zero(&psock->refcnt))
    210			psock = ERR_PTR(-EBUSY);
    211	}
    212out:
    213	rcu_read_unlock();
    214	return psock;
    215}
    216
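        /* Take references on the map's parser/verdict programs, find or create
         * sk's psock, install the programs on it, and switch sk->sk_prot (via
         * sock_map_init_proto) so sockmap hooks the socket's data paths.
         */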
    217static int sock_map_link(struct bpf_map *map, struct sock *sk)
    218{
    219	struct sk_psock_progs *progs = sock_map_progs(map);
    220	struct bpf_prog *stream_verdict = NULL;
    221	struct bpf_prog *stream_parser = NULL;
    222	struct bpf_prog *skb_verdict = NULL;
    223	struct bpf_prog *msg_parser = NULL;
    224	struct sk_psock *psock;
    225	int ret;
    226
    227	stream_verdict = READ_ONCE(progs->stream_verdict);
    228	if (stream_verdict) {
    229		stream_verdict = bpf_prog_inc_not_zero(stream_verdict);
    230		if (IS_ERR(stream_verdict))
    231			return PTR_ERR(stream_verdict);
    232	}
    233
    234	stream_parser = READ_ONCE(progs->stream_parser);
    235	if (stream_parser) {
    236		stream_parser = bpf_prog_inc_not_zero(stream_parser);
    237		if (IS_ERR(stream_parser)) {
    238			ret = PTR_ERR(stream_parser);
    239			goto out_put_stream_verdict;
    240		}
    241	}
    242
    243	msg_parser = READ_ONCE(progs->msg_parser);
    244	if (msg_parser) {
    245		msg_parser = bpf_prog_inc_not_zero(msg_parser);
    246		if (IS_ERR(msg_parser)) {
    247			ret = PTR_ERR(msg_parser);
    248			goto out_put_stream_parser;
    249		}
    250	}
    251
    252	skb_verdict = READ_ONCE(progs->skb_verdict);
    253	if (skb_verdict) {
    254		skb_verdict = bpf_prog_inc_not_zero(skb_verdict);
    255		if (IS_ERR(skb_verdict)) {
    256			ret = PTR_ERR(skb_verdict);
    257			goto out_put_msg_parser;
    258		}
    259	}
    260
    261	psock = sock_map_psock_get_checked(sk);
    262	if (IS_ERR(psock)) {
    263		ret = PTR_ERR(psock);
    264		goto out_progs;
    265	}
    266
    267	if (psock) {
    268		if ((msg_parser && READ_ONCE(psock->progs.msg_parser)) ||
    269		    (stream_parser  && READ_ONCE(psock->progs.stream_parser)) ||
    270		    (skb_verdict && READ_ONCE(psock->progs.skb_verdict)) ||
    271		    (skb_verdict && READ_ONCE(psock->progs.stream_verdict)) ||
    272		    (stream_verdict && READ_ONCE(psock->progs.skb_verdict)) ||
    273		    (stream_verdict && READ_ONCE(psock->progs.stream_verdict))) {
    274			sk_psock_put(sk, psock);
    275			ret = -EBUSY;
    276			goto out_progs;
    277		}
    278	} else {
    279		psock = sk_psock_init(sk, map->numa_node);
    280		if (IS_ERR(psock)) {
    281			ret = PTR_ERR(psock);
    282			goto out_progs;
    283		}
    284	}
    285
    286	if (msg_parser)
    287		psock_set_prog(&psock->progs.msg_parser, msg_parser);
    288	if (stream_parser)
    289		psock_set_prog(&psock->progs.stream_parser, stream_parser);
    290	if (stream_verdict)
    291		psock_set_prog(&psock->progs.stream_verdict, stream_verdict);
    292	if (skb_verdict)
    293		psock_set_prog(&psock->progs.skb_verdict, skb_verdict);
    294
     295	/* msg_* and stream_* program references are tracked in psock after this
     296	 * point. Reference dec and cleanup will occur through the psock destructor.
     297	 */
    298	ret = sock_map_init_proto(sk, psock);
    299	if (ret < 0) {
    300		sk_psock_put(sk, psock);
    301		goto out;
    302	}
    303
    304	write_lock_bh(&sk->sk_callback_lock);
    305	if (stream_parser && stream_verdict && !psock->saved_data_ready) {
    306		ret = sk_psock_init_strp(sk, psock);
    307		if (ret) {
    308			write_unlock_bh(&sk->sk_callback_lock);
    309			sk_psock_put(sk, psock);
    310			goto out;
    311		}
    312		sk_psock_start_strp(sk, psock);
    313	} else if (!stream_parser && stream_verdict && !psock->saved_data_ready) {
     314		sk_psock_start_verdict(sk, psock);
    315	} else if (!stream_verdict && skb_verdict && !psock->saved_data_ready) {
    316		sk_psock_start_verdict(sk, psock);
    317	}
    318	write_unlock_bh(&sk->sk_callback_lock);
    319	return 0;
    320out_progs:
    321	if (skb_verdict)
    322		bpf_prog_put(skb_verdict);
    323out_put_msg_parser:
    324	if (msg_parser)
    325		bpf_prog_put(msg_parser);
    326out_put_stream_parser:
    327	if (stream_parser)
    328		bpf_prog_put(stream_parser);
    329out_put_stream_verdict:
    330	if (stream_verdict)
    331		bpf_prog_put(stream_verdict);
    332out:
    333	return ret;
    334}
    335
    336static void sock_map_free(struct bpf_map *map)
    337{
    338	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
    339	int i;
    340
    341	/* After the sync no updates or deletes will be in-flight so it
    342	 * is safe to walk map and remove entries without risking a race
    343	 * in EEXIST update case.
    344	 */
    345	synchronize_rcu();
    346	for (i = 0; i < stab->map.max_entries; i++) {
    347		struct sock **psk = &stab->sks[i];
    348		struct sock *sk;
    349
    350		sk = xchg(psk, NULL);
    351		if (sk) {
    352			lock_sock(sk);
    353			rcu_read_lock();
    354			sock_map_unref(sk, psk);
    355			rcu_read_unlock();
    356			release_sock(sk);
    357		}
    358	}
    359
    360	/* wait for psock readers accessing its map link */
    361	synchronize_rcu();
    362
    363	bpf_map_area_free(stab->sks);
    364	kfree(stab);
    365}
    366
    367static void sock_map_release_progs(struct bpf_map *map)
    368{
    369	psock_progs_drop(&container_of(map, struct bpf_stab, map)->progs);
    370}
    371
    372static struct sock *__sock_map_lookup_elem(struct bpf_map *map, u32 key)
    373{
    374	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
    375
    376	WARN_ON_ONCE(!rcu_read_lock_held());
    377
    378	if (unlikely(key >= map->max_entries))
    379		return NULL;
    380	return READ_ONCE(stab->sks[key]);
    381}
    382
    383static void *sock_map_lookup(struct bpf_map *map, void *key)
    384{
    385	struct sock *sk;
    386
    387	sk = __sock_map_lookup_elem(map, *(u32 *)key);
    388	if (!sk)
    389		return NULL;
    390	if (sk_is_refcounted(sk) && !refcount_inc_not_zero(&sk->sk_refcnt))
    391		return NULL;
    392	return sk;
    393}
    394
    395static void *sock_map_lookup_sys(struct bpf_map *map, void *key)
    396{
    397	struct sock *sk;
    398
    399	if (map->value_size != sizeof(u64))
    400		return ERR_PTR(-ENOSPC);
    401
    402	sk = __sock_map_lookup_elem(map, *(u32 *)key);
    403	if (!sk)
    404		return ERR_PTR(-ENOENT);
    405
    406	__sock_gen_cookie(sk);
    407	return &sk->sk_cookie;
    408}
    409
    410static int __sock_map_delete(struct bpf_stab *stab, struct sock *sk_test,
    411			     struct sock **psk)
    412{
    413	struct sock *sk;
    414	int err = 0;
    415
    416	raw_spin_lock_bh(&stab->lock);
    417	sk = *psk;
    418	if (!sk_test || sk_test == sk)
    419		sk = xchg(psk, NULL);
    420
    421	if (likely(sk))
    422		sock_map_unref(sk, psk);
    423	else
    424		err = -EINVAL;
    425
    426	raw_spin_unlock_bh(&stab->lock);
    427	return err;
    428}
    429
    430static void sock_map_delete_from_link(struct bpf_map *map, struct sock *sk,
    431				      void *link_raw)
    432{
    433	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
    434
    435	__sock_map_delete(stab, sk, link_raw);
    436}
    437
    438static int sock_map_delete_elem(struct bpf_map *map, void *key)
    439{
    440	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
    441	u32 i = *(u32 *)key;
    442	struct sock **psk;
    443
    444	if (unlikely(i >= map->max_entries))
    445		return -EINVAL;
    446
    447	psk = &stab->sks[i];
    448	return __sock_map_delete(stab, NULL, psk);
    449}
    450
    451static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next)
    452{
    453	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
    454	u32 i = key ? *(u32 *)key : U32_MAX;
    455	u32 *key_next = next;
    456
    457	if (i == stab->map.max_entries - 1)
    458		return -ENOENT;
    459	if (i >= stab->map.max_entries)
    460		*key_next = 0;
    461	else
    462		*key_next = i + 1;
    463	return 0;
    464}
    465
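        /* Shared slow path for syscall and BPF updates on a sockmap: attach sk
         * to the map's programs via sock_map_link(), then install it in slot
         * idx under stab->lock, honoring BPF_NOEXIST/BPF_EXIST and releasing
         * any socket previously stored there.
         */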
    466static int sock_map_update_common(struct bpf_map *map, u32 idx,
    467				  struct sock *sk, u64 flags)
    468{
    469	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
    470	struct sk_psock_link *link;
    471	struct sk_psock *psock;
    472	struct sock *osk;
    473	int ret;
    474
    475	WARN_ON_ONCE(!rcu_read_lock_held());
    476	if (unlikely(flags > BPF_EXIST))
    477		return -EINVAL;
    478	if (unlikely(idx >= map->max_entries))
    479		return -E2BIG;
    480
    481	link = sk_psock_init_link();
    482	if (!link)
    483		return -ENOMEM;
    484
    485	ret = sock_map_link(map, sk);
    486	if (ret < 0)
    487		goto out_free;
    488
    489	psock = sk_psock(sk);
    490	WARN_ON_ONCE(!psock);
    491
    492	raw_spin_lock_bh(&stab->lock);
    493	osk = stab->sks[idx];
    494	if (osk && flags == BPF_NOEXIST) {
    495		ret = -EEXIST;
    496		goto out_unlock;
    497	} else if (!osk && flags == BPF_EXIST) {
    498		ret = -ENOENT;
    499		goto out_unlock;
    500	}
    501
    502	sock_map_add_link(psock, link, map, &stab->sks[idx]);
    503	stab->sks[idx] = sk;
    504	if (osk)
    505		sock_map_unref(osk, &stab->sks[idx]);
    506	raw_spin_unlock_bh(&stab->lock);
    507	return 0;
    508out_unlock:
    509	raw_spin_unlock_bh(&stab->lock);
    510	if (psock)
    511		sk_psock_put(sk, psock);
    512out_free:
    513	sk_psock_free_link(link);
    514	return ret;
    515}
    516
    517static bool sock_map_op_okay(const struct bpf_sock_ops_kern *ops)
    518{
    519	return ops->op == BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB ||
    520	       ops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB ||
    521	       ops->op == BPF_SOCK_OPS_TCP_LISTEN_CB;
    522}
    523
    524static bool sock_map_redirect_allowed(const struct sock *sk)
    525{
    526	if (sk_is_tcp(sk))
    527		return sk->sk_state != TCP_LISTEN;
    528	else
    529		return sk->sk_state == TCP_ESTABLISHED;
    530}
    531
    532static bool sock_map_sk_is_suitable(const struct sock *sk)
    533{
    534	return !!sk->sk_prot->psock_update_sk_prot;
    535}
    536
    537static bool sock_map_sk_state_allowed(const struct sock *sk)
    538{
    539	if (sk_is_tcp(sk))
    540		return (1 << sk->sk_state) & (TCPF_ESTABLISHED | TCPF_LISTEN);
    541	return true;
    542}
    543
    544static int sock_hash_update_common(struct bpf_map *map, void *key,
    545				   struct sock *sk, u64 flags);
    546
    547int sock_map_update_elem_sys(struct bpf_map *map, void *key, void *value,
    548			     u64 flags)
    549{
    550	struct socket *sock;
    551	struct sock *sk;
    552	int ret;
    553	u64 ufd;
    554
    555	if (map->value_size == sizeof(u64))
    556		ufd = *(u64 *)value;
    557	else
    558		ufd = *(u32 *)value;
    559	if (ufd > S32_MAX)
    560		return -EINVAL;
    561
    562	sock = sockfd_lookup(ufd, &ret);
    563	if (!sock)
    564		return ret;
    565	sk = sock->sk;
    566	if (!sk) {
    567		ret = -EINVAL;
    568		goto out;
    569	}
    570	if (!sock_map_sk_is_suitable(sk)) {
    571		ret = -EOPNOTSUPP;
    572		goto out;
    573	}
    574
    575	sock_map_sk_acquire(sk);
    576	if (!sock_map_sk_state_allowed(sk))
    577		ret = -EOPNOTSUPP;
    578	else if (map->map_type == BPF_MAP_TYPE_SOCKMAP)
    579		ret = sock_map_update_common(map, *(u32 *)key, sk, flags);
    580	else
    581		ret = sock_hash_update_common(map, key, sk, flags);
    582	sock_map_sk_release(sk);
    583out:
    584	sockfd_put(sock);
    585	return ret;
    586}
    587
    588static int sock_map_update_elem(struct bpf_map *map, void *key,
    589				void *value, u64 flags)
    590{
    591	struct sock *sk = (struct sock *)value;
    592	int ret;
    593
    594	if (unlikely(!sk || !sk_fullsock(sk)))
    595		return -EINVAL;
    596
    597	if (!sock_map_sk_is_suitable(sk))
    598		return -EOPNOTSUPP;
    599
    600	local_bh_disable();
    601	bh_lock_sock(sk);
    602	if (!sock_map_sk_state_allowed(sk))
    603		ret = -EOPNOTSUPP;
    604	else if (map->map_type == BPF_MAP_TYPE_SOCKMAP)
    605		ret = sock_map_update_common(map, *(u32 *)key, sk, flags);
    606	else
    607		ret = sock_hash_update_common(map, key, sk, flags);
    608	bh_unlock_sock(sk);
    609	local_bh_enable();
    610	return ret;
    611}
    612
    613BPF_CALL_4(bpf_sock_map_update, struct bpf_sock_ops_kern *, sops,
    614	   struct bpf_map *, map, void *, key, u64, flags)
    615{
    616	WARN_ON_ONCE(!rcu_read_lock_held());
    617
    618	if (likely(sock_map_sk_is_suitable(sops->sk) &&
    619		   sock_map_op_okay(sops)))
    620		return sock_map_update_common(map, *(u32 *)key, sops->sk,
    621					      flags);
    622	return -EOPNOTSUPP;
    623}
    624
    625const struct bpf_func_proto bpf_sock_map_update_proto = {
    626	.func		= bpf_sock_map_update,
    627	.gpl_only	= false,
    628	.pkt_access	= true,
    629	.ret_type	= RET_INTEGER,
    630	.arg1_type	= ARG_PTR_TO_CTX,
    631	.arg2_type	= ARG_CONST_MAP_PTR,
    632	.arg3_type	= ARG_PTR_TO_MAP_KEY,
    633	.arg4_type	= ARG_ANYTHING,
    634};
    635
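        /* Redirect helpers exposed to sk_skb/sk_msg programs: look up the
         * destination socket by map key and tag the skb or msg for
         * redirection; BPF_F_INGRESS is the only accepted flag.
         */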
    636BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb,
    637	   struct bpf_map *, map, u32, key, u64, flags)
    638{
    639	struct sock *sk;
    640
    641	if (unlikely(flags & ~(BPF_F_INGRESS)))
    642		return SK_DROP;
    643
    644	sk = __sock_map_lookup_elem(map, key);
    645	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
    646		return SK_DROP;
    647
    648	skb_bpf_set_redir(skb, sk, flags & BPF_F_INGRESS);
    649	return SK_PASS;
    650}
    651
    652const struct bpf_func_proto bpf_sk_redirect_map_proto = {
    653	.func           = bpf_sk_redirect_map,
    654	.gpl_only       = false,
    655	.ret_type       = RET_INTEGER,
    656	.arg1_type	= ARG_PTR_TO_CTX,
    657	.arg2_type      = ARG_CONST_MAP_PTR,
    658	.arg3_type      = ARG_ANYTHING,
    659	.arg4_type      = ARG_ANYTHING,
    660};
    661
    662BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg *, msg,
    663	   struct bpf_map *, map, u32, key, u64, flags)
    664{
    665	struct sock *sk;
    666
    667	if (unlikely(flags & ~(BPF_F_INGRESS)))
    668		return SK_DROP;
    669
    670	sk = __sock_map_lookup_elem(map, key);
    671	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
    672		return SK_DROP;
    673
    674	msg->flags = flags;
    675	msg->sk_redir = sk;
    676	return SK_PASS;
    677}
    678
    679const struct bpf_func_proto bpf_msg_redirect_map_proto = {
    680	.func           = bpf_msg_redirect_map,
    681	.gpl_only       = false,
    682	.ret_type       = RET_INTEGER,
    683	.arg1_type	= ARG_PTR_TO_CTX,
    684	.arg2_type      = ARG_CONST_MAP_PTR,
    685	.arg3_type      = ARG_ANYTHING,
    686	.arg4_type      = ARG_ANYTHING,
    687};
    688
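        /* bpf_iter support for sockmap: the seq_file callbacks below walk the
         * map under RCU and hand each (key, sk) pair to the iterator program.
         */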
    689struct sock_map_seq_info {
    690	struct bpf_map *map;
    691	struct sock *sk;
    692	u32 index;
    693};
    694
    695struct bpf_iter__sockmap {
    696	__bpf_md_ptr(struct bpf_iter_meta *, meta);
    697	__bpf_md_ptr(struct bpf_map *, map);
    698	__bpf_md_ptr(void *, key);
    699	__bpf_md_ptr(struct sock *, sk);
    700};
    701
    702DEFINE_BPF_ITER_FUNC(sockmap, struct bpf_iter_meta *meta,
    703		     struct bpf_map *map, void *key,
    704		     struct sock *sk)
    705
    706static void *sock_map_seq_lookup_elem(struct sock_map_seq_info *info)
    707{
    708	if (unlikely(info->index >= info->map->max_entries))
    709		return NULL;
    710
    711	info->sk = __sock_map_lookup_elem(info->map, info->index);
    712
    713	/* can't return sk directly, since that might be NULL */
    714	return info;
    715}
    716
    717static void *sock_map_seq_start(struct seq_file *seq, loff_t *pos)
    718	__acquires(rcu)
    719{
    720	struct sock_map_seq_info *info = seq->private;
    721
    722	if (*pos == 0)
    723		++*pos;
    724
    725	/* pairs with sock_map_seq_stop */
    726	rcu_read_lock();
    727	return sock_map_seq_lookup_elem(info);
    728}
    729
    730static void *sock_map_seq_next(struct seq_file *seq, void *v, loff_t *pos)
    731	__must_hold(rcu)
    732{
    733	struct sock_map_seq_info *info = seq->private;
    734
    735	++*pos;
    736	++info->index;
    737
    738	return sock_map_seq_lookup_elem(info);
    739}
    740
    741static int sock_map_seq_show(struct seq_file *seq, void *v)
    742	__must_hold(rcu)
    743{
    744	struct sock_map_seq_info *info = seq->private;
    745	struct bpf_iter__sockmap ctx = {};
    746	struct bpf_iter_meta meta;
    747	struct bpf_prog *prog;
    748
    749	meta.seq = seq;
    750	prog = bpf_iter_get_info(&meta, !v);
    751	if (!prog)
    752		return 0;
    753
    754	ctx.meta = &meta;
    755	ctx.map = info->map;
    756	if (v) {
    757		ctx.key = &info->index;
    758		ctx.sk = info->sk;
    759	}
    760
    761	return bpf_iter_run_prog(prog, &ctx);
    762}
    763
    764static void sock_map_seq_stop(struct seq_file *seq, void *v)
    765	__releases(rcu)
    766{
    767	if (!v)
    768		(void)sock_map_seq_show(seq, NULL);
    769
    770	/* pairs with sock_map_seq_start */
    771	rcu_read_unlock();
    772}
    773
    774static const struct seq_operations sock_map_seq_ops = {
    775	.start	= sock_map_seq_start,
    776	.next	= sock_map_seq_next,
    777	.stop	= sock_map_seq_stop,
    778	.show	= sock_map_seq_show,
    779};
    780
    781static int sock_map_init_seq_private(void *priv_data,
    782				     struct bpf_iter_aux_info *aux)
    783{
    784	struct sock_map_seq_info *info = priv_data;
    785
    786	info->map = aux->map;
    787	return 0;
    788}
    789
    790static const struct bpf_iter_seq_info sock_map_iter_seq_info = {
    791	.seq_ops		= &sock_map_seq_ops,
    792	.init_seq_private	= sock_map_init_seq_private,
    793	.seq_priv_size		= sizeof(struct sock_map_seq_info),
    794};
    795
    796BTF_ID_LIST_SINGLE(sock_map_btf_ids, struct, bpf_stab)
    797const struct bpf_map_ops sock_map_ops = {
    798	.map_meta_equal		= bpf_map_meta_equal,
    799	.map_alloc		= sock_map_alloc,
    800	.map_free		= sock_map_free,
    801	.map_get_next_key	= sock_map_get_next_key,
    802	.map_lookup_elem_sys_only = sock_map_lookup_sys,
    803	.map_update_elem	= sock_map_update_elem,
    804	.map_delete_elem	= sock_map_delete_elem,
    805	.map_lookup_elem	= sock_map_lookup,
    806	.map_release_uref	= sock_map_release_progs,
    807	.map_check_btf		= map_check_no_btf,
    808	.map_btf_id		= &sock_map_btf_ids[0],
    809	.iter_seq_info		= &sock_map_iter_seq_info,
    810};
    811
    812struct bpf_shtab_elem {
    813	struct rcu_head rcu;
    814	u32 hash;
    815	struct sock *sk;
    816	struct hlist_node node;
    817	u8 key[];
    818};
    819
    820struct bpf_shtab_bucket {
    821	struct hlist_head head;
    822	raw_spinlock_t lock;
    823};
    824
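        /* Backing state for a BPF_MAP_TYPE_SOCKHASH: jhash-indexed buckets of
         * RCU-linked elements, each bucket guarded by its own raw spinlock;
         * count tracks live elements against map.max_entries.
         */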
    825struct bpf_shtab {
    826	struct bpf_map map;
    827	struct bpf_shtab_bucket *buckets;
    828	u32 buckets_num;
    829	u32 elem_size;
    830	struct sk_psock_progs progs;
    831	atomic_t count;
    832};
    833
    834static inline u32 sock_hash_bucket_hash(const void *key, u32 len)
    835{
    836	return jhash(key, len, 0);
    837}
    838
    839static struct bpf_shtab_bucket *sock_hash_select_bucket(struct bpf_shtab *htab,
    840							u32 hash)
    841{
    842	return &htab->buckets[hash & (htab->buckets_num - 1)];
    843}
    844
    845static struct bpf_shtab_elem *
    846sock_hash_lookup_elem_raw(struct hlist_head *head, u32 hash, void *key,
    847			  u32 key_size)
    848{
    849	struct bpf_shtab_elem *elem;
    850
    851	hlist_for_each_entry_rcu(elem, head, node) {
    852		if (elem->hash == hash &&
    853		    !memcmp(&elem->key, key, key_size))
    854			return elem;
    855	}
    856
    857	return NULL;
    858}
    859
    860static struct sock *__sock_hash_lookup_elem(struct bpf_map *map, void *key)
    861{
    862	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
    863	u32 key_size = map->key_size, hash;
    864	struct bpf_shtab_bucket *bucket;
    865	struct bpf_shtab_elem *elem;
    866
    867	WARN_ON_ONCE(!rcu_read_lock_held());
    868
    869	hash = sock_hash_bucket_hash(key, key_size);
    870	bucket = sock_hash_select_bucket(htab, hash);
    871	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
    872
    873	return elem ? elem->sk : NULL;
    874}
    875
    876static void sock_hash_free_elem(struct bpf_shtab *htab,
    877				struct bpf_shtab_elem *elem)
    878{
    879	atomic_dec(&htab->count);
    880	kfree_rcu(elem, rcu);
    881}
    882
    883static void sock_hash_delete_from_link(struct bpf_map *map, struct sock *sk,
    884				       void *link_raw)
    885{
    886	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
    887	struct bpf_shtab_elem *elem_probe, *elem = link_raw;
    888	struct bpf_shtab_bucket *bucket;
    889
    890	WARN_ON_ONCE(!rcu_read_lock_held());
    891	bucket = sock_hash_select_bucket(htab, elem->hash);
    892
    893	/* elem may be deleted in parallel from the map, but access here
    894	 * is okay since it's going away only after RCU grace period.
    895	 * However, we need to check whether it's still present.
    896	 */
    897	raw_spin_lock_bh(&bucket->lock);
    898	elem_probe = sock_hash_lookup_elem_raw(&bucket->head, elem->hash,
    899					       elem->key, map->key_size);
    900	if (elem_probe && elem_probe == elem) {
    901		hlist_del_rcu(&elem->node);
    902		sock_map_unref(elem->sk, elem);
    903		sock_hash_free_elem(htab, elem);
    904	}
    905	raw_spin_unlock_bh(&bucket->lock);
    906}
    907
    908static int sock_hash_delete_elem(struct bpf_map *map, void *key)
    909{
    910	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
    911	u32 hash, key_size = map->key_size;
    912	struct bpf_shtab_bucket *bucket;
    913	struct bpf_shtab_elem *elem;
    914	int ret = -ENOENT;
    915
    916	hash = sock_hash_bucket_hash(key, key_size);
    917	bucket = sock_hash_select_bucket(htab, hash);
    918
    919	raw_spin_lock_bh(&bucket->lock);
    920	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
    921	if (elem) {
    922		hlist_del_rcu(&elem->node);
    923		sock_map_unref(elem->sk, elem);
    924		sock_hash_free_elem(htab, elem);
    925		ret = 0;
    926	}
    927	raw_spin_unlock_bh(&bucket->lock);
    928	return ret;
    929}
    930
    931static struct bpf_shtab_elem *sock_hash_alloc_elem(struct bpf_shtab *htab,
    932						   void *key, u32 key_size,
    933						   u32 hash, struct sock *sk,
    934						   struct bpf_shtab_elem *old)
    935{
    936	struct bpf_shtab_elem *new;
    937
    938	if (atomic_inc_return(&htab->count) > htab->map.max_entries) {
    939		if (!old) {
    940			atomic_dec(&htab->count);
    941			return ERR_PTR(-E2BIG);
    942		}
    943	}
    944
    945	new = bpf_map_kmalloc_node(&htab->map, htab->elem_size,
    946				   GFP_ATOMIC | __GFP_NOWARN,
    947				   htab->map.numa_node);
    948	if (!new) {
    949		atomic_dec(&htab->count);
    950		return ERR_PTR(-ENOMEM);
    951	}
    952	memcpy(new->key, key, key_size);
    953	new->sk = sk;
    954	new->hash = hash;
    955	return new;
    956}
    957
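        /* Hash-table counterpart of sock_map_update_common(): attach sk to the
         * map, allocate a new element, publish it at the bucket head for RCU
         * readers, then unlink and free any element it replaces.
         */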
    958static int sock_hash_update_common(struct bpf_map *map, void *key,
    959				   struct sock *sk, u64 flags)
    960{
    961	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
    962	u32 key_size = map->key_size, hash;
    963	struct bpf_shtab_elem *elem, *elem_new;
    964	struct bpf_shtab_bucket *bucket;
    965	struct sk_psock_link *link;
    966	struct sk_psock *psock;
    967	int ret;
    968
    969	WARN_ON_ONCE(!rcu_read_lock_held());
    970	if (unlikely(flags > BPF_EXIST))
    971		return -EINVAL;
    972
    973	link = sk_psock_init_link();
    974	if (!link)
    975		return -ENOMEM;
    976
    977	ret = sock_map_link(map, sk);
    978	if (ret < 0)
    979		goto out_free;
    980
    981	psock = sk_psock(sk);
    982	WARN_ON_ONCE(!psock);
    983
    984	hash = sock_hash_bucket_hash(key, key_size);
    985	bucket = sock_hash_select_bucket(htab, hash);
    986
    987	raw_spin_lock_bh(&bucket->lock);
    988	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
    989	if (elem && flags == BPF_NOEXIST) {
    990		ret = -EEXIST;
    991		goto out_unlock;
    992	} else if (!elem && flags == BPF_EXIST) {
    993		ret = -ENOENT;
    994		goto out_unlock;
    995	}
    996
    997	elem_new = sock_hash_alloc_elem(htab, key, key_size, hash, sk, elem);
    998	if (IS_ERR(elem_new)) {
    999		ret = PTR_ERR(elem_new);
   1000		goto out_unlock;
   1001	}
   1002
   1003	sock_map_add_link(psock, link, map, elem_new);
   1004	/* Add new element to the head of the list, so that
   1005	 * concurrent search will find it before old elem.
   1006	 */
   1007	hlist_add_head_rcu(&elem_new->node, &bucket->head);
   1008	if (elem) {
   1009		hlist_del_rcu(&elem->node);
   1010		sock_map_unref(elem->sk, elem);
   1011		sock_hash_free_elem(htab, elem);
   1012	}
   1013	raw_spin_unlock_bh(&bucket->lock);
   1014	return 0;
   1015out_unlock:
   1016	raw_spin_unlock_bh(&bucket->lock);
   1017	sk_psock_put(sk, psock);
   1018out_free:
   1019	sk_psock_free_link(link);
   1020	return ret;
   1021}
   1022
   1023static int sock_hash_get_next_key(struct bpf_map *map, void *key,
   1024				  void *key_next)
   1025{
   1026	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
   1027	struct bpf_shtab_elem *elem, *elem_next;
   1028	u32 hash, key_size = map->key_size;
   1029	struct hlist_head *head;
   1030	int i = 0;
   1031
   1032	if (!key)
   1033		goto find_first_elem;
   1034	hash = sock_hash_bucket_hash(key, key_size);
   1035	head = &sock_hash_select_bucket(htab, hash)->head;
   1036	elem = sock_hash_lookup_elem_raw(head, hash, key, key_size);
   1037	if (!elem)
   1038		goto find_first_elem;
   1039
   1040	elem_next = hlist_entry_safe(rcu_dereference(hlist_next_rcu(&elem->node)),
   1041				     struct bpf_shtab_elem, node);
   1042	if (elem_next) {
   1043		memcpy(key_next, elem_next->key, key_size);
   1044		return 0;
   1045	}
   1046
   1047	i = hash & (htab->buckets_num - 1);
   1048	i++;
   1049find_first_elem:
   1050	for (; i < htab->buckets_num; i++) {
   1051		head = &sock_hash_select_bucket(htab, i)->head;
   1052		elem_next = hlist_entry_safe(rcu_dereference(hlist_first_rcu(head)),
   1053					     struct bpf_shtab_elem, node);
   1054		if (elem_next) {
   1055			memcpy(key_next, elem_next->key, key_size);
   1056			return 0;
   1057		}
   1058	}
   1059
   1060	return -ENOENT;
   1061}
   1062
   1063static struct bpf_map *sock_hash_alloc(union bpf_attr *attr)
   1064{
   1065	struct bpf_shtab *htab;
   1066	int i, err;
   1067
   1068	if (!capable(CAP_NET_ADMIN))
   1069		return ERR_PTR(-EPERM);
   1070	if (attr->max_entries == 0 ||
   1071	    attr->key_size    == 0 ||
   1072	    (attr->value_size != sizeof(u32) &&
   1073	     attr->value_size != sizeof(u64)) ||
   1074	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
   1075		return ERR_PTR(-EINVAL);
   1076	if (attr->key_size > MAX_BPF_STACK)
   1077		return ERR_PTR(-E2BIG);
   1078
   1079	htab = kzalloc(sizeof(*htab), GFP_USER | __GFP_ACCOUNT);
   1080	if (!htab)
   1081		return ERR_PTR(-ENOMEM);
   1082
   1083	bpf_map_init_from_attr(&htab->map, attr);
   1084
   1085	htab->buckets_num = roundup_pow_of_two(htab->map.max_entries);
   1086	htab->elem_size = sizeof(struct bpf_shtab_elem) +
   1087			  round_up(htab->map.key_size, 8);
   1088	if (htab->buckets_num == 0 ||
   1089	    htab->buckets_num > U32_MAX / sizeof(struct bpf_shtab_bucket)) {
   1090		err = -EINVAL;
   1091		goto free_htab;
   1092	}
   1093
   1094	htab->buckets = bpf_map_area_alloc(htab->buckets_num *
   1095					   sizeof(struct bpf_shtab_bucket),
   1096					   htab->map.numa_node);
   1097	if (!htab->buckets) {
   1098		err = -ENOMEM;
   1099		goto free_htab;
   1100	}
   1101
   1102	for (i = 0; i < htab->buckets_num; i++) {
   1103		INIT_HLIST_HEAD(&htab->buckets[i].head);
   1104		raw_spin_lock_init(&htab->buckets[i].lock);
   1105	}
   1106
   1107	return &htab->map;
   1108free_htab:
   1109	kfree(htab);
   1110	return ERR_PTR(err);
   1111}
   1112
   1113static void sock_hash_free(struct bpf_map *map)
   1114{
   1115	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
   1116	struct bpf_shtab_bucket *bucket;
   1117	struct hlist_head unlink_list;
   1118	struct bpf_shtab_elem *elem;
   1119	struct hlist_node *node;
   1120	int i;
   1121
   1122	/* After the sync no updates or deletes will be in-flight so it
   1123	 * is safe to walk map and remove entries without risking a race
   1124	 * in EEXIST update case.
   1125	 */
   1126	synchronize_rcu();
   1127	for (i = 0; i < htab->buckets_num; i++) {
   1128		bucket = sock_hash_select_bucket(htab, i);
   1129
   1130		/* We are racing with sock_hash_delete_from_link to
   1131		 * enter the spin-lock critical section. Every socket on
   1132		 * the list is still linked to sockhash. Since link
   1133		 * exists, psock exists and holds a ref to socket. That
   1134		 * lets us to grab a socket ref too.
    1135		 * lets us grab a socket ref too.
   1136		raw_spin_lock_bh(&bucket->lock);
   1137		hlist_for_each_entry(elem, &bucket->head, node)
   1138			sock_hold(elem->sk);
   1139		hlist_move_list(&bucket->head, &unlink_list);
   1140		raw_spin_unlock_bh(&bucket->lock);
   1141
   1142		/* Process removed entries out of atomic context to
   1143		 * block for socket lock before deleting the psock's
   1144		 * link to sockhash.
   1145		 */
   1146		hlist_for_each_entry_safe(elem, node, &unlink_list, node) {
   1147			hlist_del(&elem->node);
   1148			lock_sock(elem->sk);
   1149			rcu_read_lock();
   1150			sock_map_unref(elem->sk, elem);
   1151			rcu_read_unlock();
   1152			release_sock(elem->sk);
   1153			sock_put(elem->sk);
   1154			sock_hash_free_elem(htab, elem);
   1155		}
   1156	}
   1157
   1158	/* wait for psock readers accessing its map link */
   1159	synchronize_rcu();
   1160
   1161	bpf_map_area_free(htab->buckets);
   1162	kfree(htab);
   1163}
   1164
   1165static void *sock_hash_lookup_sys(struct bpf_map *map, void *key)
   1166{
   1167	struct sock *sk;
   1168
   1169	if (map->value_size != sizeof(u64))
   1170		return ERR_PTR(-ENOSPC);
   1171
   1172	sk = __sock_hash_lookup_elem(map, key);
   1173	if (!sk)
   1174		return ERR_PTR(-ENOENT);
   1175
   1176	__sock_gen_cookie(sk);
   1177	return &sk->sk_cookie;
   1178}
   1179
   1180static void *sock_hash_lookup(struct bpf_map *map, void *key)
   1181{
   1182	struct sock *sk;
   1183
   1184	sk = __sock_hash_lookup_elem(map, key);
   1185	if (!sk)
   1186		return NULL;
   1187	if (sk_is_refcounted(sk) && !refcount_inc_not_zero(&sk->sk_refcnt))
   1188		return NULL;
   1189	return sk;
   1190}
   1191
   1192static void sock_hash_release_progs(struct bpf_map *map)
   1193{
   1194	psock_progs_drop(&container_of(map, struct bpf_shtab, map)->progs);
   1195}
   1196
   1197BPF_CALL_4(bpf_sock_hash_update, struct bpf_sock_ops_kern *, sops,
   1198	   struct bpf_map *, map, void *, key, u64, flags)
   1199{
   1200	WARN_ON_ONCE(!rcu_read_lock_held());
   1201
   1202	if (likely(sock_map_sk_is_suitable(sops->sk) &&
   1203		   sock_map_op_okay(sops)))
   1204		return sock_hash_update_common(map, key, sops->sk, flags);
   1205	return -EOPNOTSUPP;
   1206}
   1207
   1208const struct bpf_func_proto bpf_sock_hash_update_proto = {
   1209	.func		= bpf_sock_hash_update,
   1210	.gpl_only	= false,
   1211	.pkt_access	= true,
   1212	.ret_type	= RET_INTEGER,
   1213	.arg1_type	= ARG_PTR_TO_CTX,
   1214	.arg2_type	= ARG_CONST_MAP_PTR,
   1215	.arg3_type	= ARG_PTR_TO_MAP_KEY,
   1216	.arg4_type	= ARG_ANYTHING,
   1217};
   1218
   1219BPF_CALL_4(bpf_sk_redirect_hash, struct sk_buff *, skb,
   1220	   struct bpf_map *, map, void *, key, u64, flags)
   1221{
   1222	struct sock *sk;
   1223
   1224	if (unlikely(flags & ~(BPF_F_INGRESS)))
   1225		return SK_DROP;
   1226
   1227	sk = __sock_hash_lookup_elem(map, key);
   1228	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
   1229		return SK_DROP;
   1230
   1231	skb_bpf_set_redir(skb, sk, flags & BPF_F_INGRESS);
   1232	return SK_PASS;
   1233}
   1234
   1235const struct bpf_func_proto bpf_sk_redirect_hash_proto = {
   1236	.func           = bpf_sk_redirect_hash,
   1237	.gpl_only       = false,
   1238	.ret_type       = RET_INTEGER,
   1239	.arg1_type	= ARG_PTR_TO_CTX,
   1240	.arg2_type      = ARG_CONST_MAP_PTR,
   1241	.arg3_type      = ARG_PTR_TO_MAP_KEY,
   1242	.arg4_type      = ARG_ANYTHING,
   1243};
   1244
   1245BPF_CALL_4(bpf_msg_redirect_hash, struct sk_msg *, msg,
   1246	   struct bpf_map *, map, void *, key, u64, flags)
   1247{
   1248	struct sock *sk;
   1249
   1250	if (unlikely(flags & ~(BPF_F_INGRESS)))
   1251		return SK_DROP;
   1252
   1253	sk = __sock_hash_lookup_elem(map, key);
   1254	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
   1255		return SK_DROP;
   1256
   1257	msg->flags = flags;
   1258	msg->sk_redir = sk;
   1259	return SK_PASS;
   1260}
   1261
   1262const struct bpf_func_proto bpf_msg_redirect_hash_proto = {
   1263	.func           = bpf_msg_redirect_hash,
   1264	.gpl_only       = false,
   1265	.ret_type       = RET_INTEGER,
   1266	.arg1_type	= ARG_PTR_TO_CTX,
   1267	.arg2_type      = ARG_CONST_MAP_PTR,
   1268	.arg3_type      = ARG_PTR_TO_MAP_KEY,
   1269	.arg4_type      = ARG_ANYTHING,
   1270};
   1271
   1272struct sock_hash_seq_info {
   1273	struct bpf_map *map;
   1274	struct bpf_shtab *htab;
   1275	u32 bucket_id;
   1276};
   1277
   1278static void *sock_hash_seq_find_next(struct sock_hash_seq_info *info,
   1279				     struct bpf_shtab_elem *prev_elem)
   1280{
   1281	const struct bpf_shtab *htab = info->htab;
   1282	struct bpf_shtab_bucket *bucket;
   1283	struct bpf_shtab_elem *elem;
   1284	struct hlist_node *node;
   1285
   1286	/* try to find next elem in the same bucket */
   1287	if (prev_elem) {
   1288		node = rcu_dereference(hlist_next_rcu(&prev_elem->node));
   1289		elem = hlist_entry_safe(node, struct bpf_shtab_elem, node);
   1290		if (elem)
   1291			return elem;
   1292
   1293		/* no more elements, continue in the next bucket */
   1294		info->bucket_id++;
   1295	}
   1296
   1297	for (; info->bucket_id < htab->buckets_num; info->bucket_id++) {
   1298		bucket = &htab->buckets[info->bucket_id];
   1299		node = rcu_dereference(hlist_first_rcu(&bucket->head));
   1300		elem = hlist_entry_safe(node, struct bpf_shtab_elem, node);
   1301		if (elem)
   1302			return elem;
   1303	}
   1304
   1305	return NULL;
   1306}
   1307
   1308static void *sock_hash_seq_start(struct seq_file *seq, loff_t *pos)
   1309	__acquires(rcu)
   1310{
   1311	struct sock_hash_seq_info *info = seq->private;
   1312
   1313	if (*pos == 0)
   1314		++*pos;
   1315
   1316	/* pairs with sock_hash_seq_stop */
   1317	rcu_read_lock();
   1318	return sock_hash_seq_find_next(info, NULL);
   1319}
   1320
   1321static void *sock_hash_seq_next(struct seq_file *seq, void *v, loff_t *pos)
   1322	__must_hold(rcu)
   1323{
   1324	struct sock_hash_seq_info *info = seq->private;
   1325
   1326	++*pos;
   1327	return sock_hash_seq_find_next(info, v);
   1328}
   1329
   1330static int sock_hash_seq_show(struct seq_file *seq, void *v)
   1331	__must_hold(rcu)
   1332{
   1333	struct sock_hash_seq_info *info = seq->private;
   1334	struct bpf_iter__sockmap ctx = {};
   1335	struct bpf_shtab_elem *elem = v;
   1336	struct bpf_iter_meta meta;
   1337	struct bpf_prog *prog;
   1338
   1339	meta.seq = seq;
   1340	prog = bpf_iter_get_info(&meta, !elem);
   1341	if (!prog)
   1342		return 0;
   1343
   1344	ctx.meta = &meta;
   1345	ctx.map = info->map;
   1346	if (elem) {
   1347		ctx.key = elem->key;
   1348		ctx.sk = elem->sk;
   1349	}
   1350
   1351	return bpf_iter_run_prog(prog, &ctx);
   1352}
   1353
   1354static void sock_hash_seq_stop(struct seq_file *seq, void *v)
   1355	__releases(rcu)
   1356{
   1357	if (!v)
   1358		(void)sock_hash_seq_show(seq, NULL);
   1359
   1360	/* pairs with sock_hash_seq_start */
   1361	rcu_read_unlock();
   1362}
   1363
   1364static const struct seq_operations sock_hash_seq_ops = {
   1365	.start	= sock_hash_seq_start,
   1366	.next	= sock_hash_seq_next,
   1367	.stop	= sock_hash_seq_stop,
   1368	.show	= sock_hash_seq_show,
   1369};
   1370
   1371static int sock_hash_init_seq_private(void *priv_data,
   1372				     struct bpf_iter_aux_info *aux)
   1373{
   1374	struct sock_hash_seq_info *info = priv_data;
   1375
   1376	info->map = aux->map;
   1377	info->htab = container_of(aux->map, struct bpf_shtab, map);
   1378	return 0;
   1379}
   1380
   1381static const struct bpf_iter_seq_info sock_hash_iter_seq_info = {
   1382	.seq_ops		= &sock_hash_seq_ops,
   1383	.init_seq_private	= sock_hash_init_seq_private,
   1384	.seq_priv_size		= sizeof(struct sock_hash_seq_info),
   1385};
   1386
   1387BTF_ID_LIST_SINGLE(sock_hash_map_btf_ids, struct, bpf_shtab)
   1388const struct bpf_map_ops sock_hash_ops = {
   1389	.map_meta_equal		= bpf_map_meta_equal,
   1390	.map_alloc		= sock_hash_alloc,
   1391	.map_free		= sock_hash_free,
   1392	.map_get_next_key	= sock_hash_get_next_key,
   1393	.map_update_elem	= sock_map_update_elem,
   1394	.map_delete_elem	= sock_hash_delete_elem,
   1395	.map_lookup_elem	= sock_hash_lookup,
   1396	.map_lookup_elem_sys_only = sock_hash_lookup_sys,
   1397	.map_release_uref	= sock_hash_release_progs,
   1398	.map_check_btf		= map_check_no_btf,
   1399	.map_btf_id		= &sock_hash_map_btf_ids[0],
   1400	.iter_seq_info		= &sock_hash_iter_seq_info,
   1401};
   1402
   1403static struct sk_psock_progs *sock_map_progs(struct bpf_map *map)
   1404{
   1405	switch (map->map_type) {
   1406	case BPF_MAP_TYPE_SOCKMAP:
   1407		return &container_of(map, struct bpf_stab, map)->progs;
   1408	case BPF_MAP_TYPE_SOCKHASH:
   1409		return &container_of(map, struct bpf_shtab, map)->progs;
   1410	default:
   1411		break;
   1412	}
   1413
   1414	return NULL;
   1415}
   1416
   1417static int sock_map_prog_lookup(struct bpf_map *map, struct bpf_prog ***pprog,
   1418				u32 which)
   1419{
   1420	struct sk_psock_progs *progs = sock_map_progs(map);
   1421
   1422	if (!progs)
   1423		return -EOPNOTSUPP;
   1424
   1425	switch (which) {
   1426	case BPF_SK_MSG_VERDICT:
   1427		*pprog = &progs->msg_parser;
   1428		break;
   1429#if IS_ENABLED(CONFIG_BPF_STREAM_PARSER)
   1430	case BPF_SK_SKB_STREAM_PARSER:
   1431		*pprog = &progs->stream_parser;
   1432		break;
   1433#endif
   1434	case BPF_SK_SKB_STREAM_VERDICT:
   1435		if (progs->skb_verdict)
   1436			return -EBUSY;
   1437		*pprog = &progs->stream_verdict;
   1438		break;
   1439	case BPF_SK_SKB_VERDICT:
   1440		if (progs->stream_verdict)
   1441			return -EBUSY;
   1442		*pprog = &progs->skb_verdict;
   1443		break;
   1444	default:
   1445		return -EOPNOTSUPP;
   1446	}
   1447
   1448	return 0;
   1449}
   1450
   1451static int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
   1452				struct bpf_prog *old, u32 which)
   1453{
   1454	struct bpf_prog **pprog;
   1455	int ret;
   1456
   1457	ret = sock_map_prog_lookup(map, &pprog, which);
   1458	if (ret)
   1459		return ret;
   1460
   1461	if (old)
   1462		return psock_replace_prog(pprog, prog, old);
   1463
   1464	psock_set_prog(pprog, prog);
   1465	return 0;
   1466}
   1467
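        /* BPF_PROG_QUERY handler: a sock map holds at most one program per
         * attach type, so this reports either zero or one program id.
         */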
   1468int sock_map_bpf_prog_query(const union bpf_attr *attr,
   1469			    union bpf_attr __user *uattr)
   1470{
   1471	__u32 __user *prog_ids = u64_to_user_ptr(attr->query.prog_ids);
   1472	u32 prog_cnt = 0, flags = 0, ufd = attr->target_fd;
   1473	struct bpf_prog **pprog;
   1474	struct bpf_prog *prog;
   1475	struct bpf_map *map;
   1476	struct fd f;
   1477	u32 id = 0;
   1478	int ret;
   1479
   1480	if (attr->query.query_flags)
   1481		return -EINVAL;
   1482
   1483	f = fdget(ufd);
   1484	map = __bpf_map_get(f);
   1485	if (IS_ERR(map))
   1486		return PTR_ERR(map);
   1487
   1488	rcu_read_lock();
   1489
   1490	ret = sock_map_prog_lookup(map, &pprog, attr->query.attach_type);
   1491	if (ret)
   1492		goto end;
   1493
   1494	prog = *pprog;
   1495	prog_cnt = !prog ? 0 : 1;
   1496
   1497	if (!attr->query.prog_cnt || !prog_ids || !prog_cnt)
   1498		goto end;
   1499
    1500	/* We do not hold the refcnt; the bpf prog may be released
    1501	 * asynchronously, in which case the id would be set to 0.
    1502	 */
   1503	id = data_race(prog->aux->id);
   1504	if (id == 0)
   1505		prog_cnt = 0;
   1506
   1507end:
   1508	rcu_read_unlock();
   1509
   1510	if (copy_to_user(&uattr->query.attach_flags, &flags, sizeof(flags)) ||
   1511	    (id != 0 && copy_to_user(prog_ids, &id, sizeof(u32))) ||
   1512	    copy_to_user(&uattr->query.prog_cnt, &prog_cnt, sizeof(prog_cnt)))
   1513		ret = -EFAULT;
   1514
   1515	fdput(f);
   1516	return ret;
   1517}
   1518
   1519static void sock_map_unlink(struct sock *sk, struct sk_psock_link *link)
   1520{
   1521	switch (link->map->map_type) {
   1522	case BPF_MAP_TYPE_SOCKMAP:
   1523		return sock_map_delete_from_link(link->map, sk,
   1524						 link->link_raw);
   1525	case BPF_MAP_TYPE_SOCKHASH:
   1526		return sock_hash_delete_from_link(link->map, sk,
   1527						  link->link_raw);
   1528	default:
   1529		break;
   1530	}
   1531}
   1532
   1533static void sock_map_remove_links(struct sock *sk, struct sk_psock *psock)
   1534{
   1535	struct sk_psock_link *link;
   1536
   1537	while ((link = sk_psock_link_pop(psock))) {
   1538		sock_map_unlink(sk, link);
   1539		sk_psock_free_link(link);
   1540	}
   1541}
   1542
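        /* sock_map_unhash() and sock_map_close() are installed into sk->sk_prot
         * by psock_update_sk_prot(); they remove all of the socket's map links
         * before falling back to the saved protocol callbacks.
         */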
   1543void sock_map_unhash(struct sock *sk)
   1544{
   1545	void (*saved_unhash)(struct sock *sk);
   1546	struct sk_psock *psock;
   1547
   1548	rcu_read_lock();
   1549	psock = sk_psock(sk);
   1550	if (unlikely(!psock)) {
   1551		rcu_read_unlock();
   1552		if (sk->sk_prot->unhash)
   1553			sk->sk_prot->unhash(sk);
   1554		return;
   1555	}
   1556
   1557	saved_unhash = psock->saved_unhash;
   1558	sock_map_remove_links(sk, psock);
   1559	rcu_read_unlock();
   1560	saved_unhash(sk);
   1561}
   1562EXPORT_SYMBOL_GPL(sock_map_unhash);
   1563
   1564void sock_map_close(struct sock *sk, long timeout)
   1565{
   1566	void (*saved_close)(struct sock *sk, long timeout);
   1567	struct sk_psock *psock;
   1568
   1569	lock_sock(sk);
   1570	rcu_read_lock();
   1571	psock = sk_psock_get(sk);
   1572	if (unlikely(!psock)) {
   1573		rcu_read_unlock();
   1574		release_sock(sk);
   1575		return sk->sk_prot->close(sk, timeout);
   1576	}
   1577
   1578	saved_close = psock->saved_close;
   1579	sock_map_remove_links(sk, psock);
   1580	rcu_read_unlock();
   1581	sk_psock_stop(psock, true);
   1582	sk_psock_put(sk, psock);
   1583	release_sock(sk);
   1584	saved_close(sk, timeout);
   1585}
   1586EXPORT_SYMBOL_GPL(sock_map_close);
   1587
   1588static int sock_map_iter_attach_target(struct bpf_prog *prog,
   1589				       union bpf_iter_link_info *linfo,
   1590				       struct bpf_iter_aux_info *aux)
   1591{
   1592	struct bpf_map *map;
   1593	int err = -EINVAL;
   1594
   1595	if (!linfo->map.map_fd)
   1596		return -EBADF;
   1597
   1598	map = bpf_map_get_with_uref(linfo->map.map_fd);
   1599	if (IS_ERR(map))
   1600		return PTR_ERR(map);
   1601
   1602	if (map->map_type != BPF_MAP_TYPE_SOCKMAP &&
   1603	    map->map_type != BPF_MAP_TYPE_SOCKHASH)
   1604		goto put_map;
   1605
   1606	if (prog->aux->max_rdonly_access > map->key_size) {
   1607		err = -EACCES;
   1608		goto put_map;
   1609	}
   1610
   1611	aux->map = map;
   1612	return 0;
   1613
   1614put_map:
   1615	bpf_map_put_with_uref(map);
   1616	return err;
   1617}
   1618
   1619static void sock_map_iter_detach_target(struct bpf_iter_aux_info *aux)
   1620{
   1621	bpf_map_put_with_uref(aux->map);
   1622}
   1623
   1624static struct bpf_iter_reg sock_map_iter_reg = {
   1625	.target			= "sockmap",
   1626	.attach_target		= sock_map_iter_attach_target,
   1627	.detach_target		= sock_map_iter_detach_target,
   1628	.show_fdinfo		= bpf_iter_map_show_fdinfo,
   1629	.fill_link_info		= bpf_iter_map_fill_link_info,
   1630	.ctx_arg_info_size	= 2,
   1631	.ctx_arg_info		= {
   1632		{ offsetof(struct bpf_iter__sockmap, key),
   1633		  PTR_TO_BUF | PTR_MAYBE_NULL | MEM_RDONLY },
   1634		{ offsetof(struct bpf_iter__sockmap, sk),
   1635		  PTR_TO_BTF_ID_OR_NULL },
   1636	},
   1637};
   1638
   1639static int __init bpf_sockmap_iter_init(void)
   1640{
   1641	sock_map_iter_reg.ctx_arg_info[1].btf_id =
   1642		btf_sock_ids[BTF_SOCK_TYPE_SOCK];
   1643	return bpf_iter_reg_target(&sock_map_iter_reg);
   1644}
   1645late_initcall(bpf_sockmap_iter_init);