cachepc-linux

Fork of AMDESE/linux with modifications for the CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux

bpf_sk_storage.c (23942B)


// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2019 Facebook  */
#include <linux/rculist.h>
#include <linux/list.h>
#include <linux/hash.h>
#include <linux/types.h>
#include <linux/spinlock.h>
#include <linux/bpf.h>
#include <linux/btf.h>
#include <linux/btf_ids.h>
#include <linux/bpf_local_storage.h>
#include <net/bpf_sk_storage.h>
#include <net/sock.h>
#include <uapi/linux/sock_diag.h>
#include <uapi/linux/btf.h>
#include <linux/rcupdate_trace.h>

DEFINE_BPF_STORAGE_CACHE(sk_cache);

static struct bpf_local_storage_data *
bpf_sk_storage_lookup(struct sock *sk, struct bpf_map *map, bool cacheit_lockit)
{
	struct bpf_local_storage *sk_storage;
	struct bpf_local_storage_map *smap;

	sk_storage =
		rcu_dereference_check(sk->sk_bpf_storage, bpf_rcu_lock_held());
	if (!sk_storage)
		return NULL;

	smap = (struct bpf_local_storage_map *)map;
	return bpf_local_storage_lookup(sk_storage, smap, cacheit_lockit);
}

static int bpf_sk_storage_del(struct sock *sk, struct bpf_map *map)
{
	struct bpf_local_storage_data *sdata;

	sdata = bpf_sk_storage_lookup(sk, map, false);
	if (!sdata)
		return -ENOENT;

	bpf_selem_unlink(SELEM(sdata), true);

	return 0;
}

/* Called by __sk_destruct() & bpf_sk_storage_clone() */
void bpf_sk_storage_free(struct sock *sk)
{
	struct bpf_local_storage_elem *selem;
	struct bpf_local_storage *sk_storage;
	bool free_sk_storage = false;
	struct hlist_node *n;

	rcu_read_lock();
	sk_storage = rcu_dereference(sk->sk_bpf_storage);
	if (!sk_storage) {
		rcu_read_unlock();
		return;
	}

	/* Neither the bpf_prog nor the bpf-map's syscall
	 * could be modifying the sk_storage->list now.
	 * Thus, no elem can be added-to or deleted-from the
	 * sk_storage->list by the bpf_prog or by the bpf-map's syscall.
	 *
	 * It is racing with bpf_local_storage_map_free() alone
	 * when unlinking elem from the sk_storage->list and
	 * the map's bucket->list.
	 */
	raw_spin_lock_bh(&sk_storage->lock);
	hlist_for_each_entry_safe(selem, n, &sk_storage->list, snode) {
		/* Always unlink from map before unlinking from
		 * sk_storage.
		 */
		bpf_selem_unlink_map(selem);
		free_sk_storage = bpf_selem_unlink_storage_nolock(
			sk_storage, selem, true, false);
	}
	raw_spin_unlock_bh(&sk_storage->lock);
	rcu_read_unlock();

	if (free_sk_storage)
		kfree_rcu(sk_storage, rcu);
}

static void bpf_sk_storage_map_free(struct bpf_map *map)
{
	struct bpf_local_storage_map *smap;

	smap = (struct bpf_local_storage_map *)map;
	bpf_local_storage_cache_idx_free(&sk_cache, smap->cache_idx);
	bpf_local_storage_map_free(smap, NULL);
}

static struct bpf_map *bpf_sk_storage_map_alloc(union bpf_attr *attr)
{
	struct bpf_local_storage_map *smap;

	smap = bpf_local_storage_map_alloc(attr);
	if (IS_ERR(smap))
		return ERR_CAST(smap);

	smap->cache_idx = bpf_local_storage_cache_idx_get(&sk_cache);
	return &smap->map;
}

static int notsupp_get_next_key(struct bpf_map *map, void *key,
				void *next_key)
{
	return -ENOTSUPP;
}

static void *bpf_fd_sk_storage_lookup_elem(struct bpf_map *map, void *key)
{
	struct bpf_local_storage_data *sdata;
	struct socket *sock;
	int fd, err;

	fd = *(int *)key;
	sock = sockfd_lookup(fd, &err);
	if (sock) {
		sdata = bpf_sk_storage_lookup(sock->sk, map, true);
		sockfd_put(sock);
		return sdata ? sdata->data : NULL;
	}

	return ERR_PTR(err);
}

static int bpf_fd_sk_storage_update_elem(struct bpf_map *map, void *key,
					 void *value, u64 map_flags)
{
	struct bpf_local_storage_data *sdata;
	struct socket *sock;
	int fd, err;

	fd = *(int *)key;
	sock = sockfd_lookup(fd, &err);
	if (sock) {
		sdata = bpf_local_storage_update(
			sock->sk, (struct bpf_local_storage_map *)map, value,
			map_flags, GFP_ATOMIC);
		sockfd_put(sock);
		return PTR_ERR_OR_ZERO(sdata);
	}

	return err;
}

static int bpf_fd_sk_storage_delete_elem(struct bpf_map *map, void *key)
{
	struct socket *sock;
	int fd, err;

	fd = *(int *)key;
	sock = sockfd_lookup(fd, &err);
	if (sock) {
		err = bpf_sk_storage_del(sock->sk, map);
		sockfd_put(sock);
		return err;
	}

	return err;
}

static struct bpf_local_storage_elem *
bpf_sk_storage_clone_elem(struct sock *newsk,
			  struct bpf_local_storage_map *smap,
			  struct bpf_local_storage_elem *selem)
{
	struct bpf_local_storage_elem *copy_selem;

	copy_selem = bpf_selem_alloc(smap, newsk, NULL, true, GFP_ATOMIC);
	if (!copy_selem)
		return NULL;

	if (map_value_has_spin_lock(&smap->map))
		copy_map_value_locked(&smap->map, SDATA(copy_selem)->data,
				      SDATA(selem)->data, true);
	else
		copy_map_value(&smap->map, SDATA(copy_selem)->data,
			       SDATA(selem)->data);

	return copy_selem;
}

int bpf_sk_storage_clone(const struct sock *sk, struct sock *newsk)
{
	struct bpf_local_storage *new_sk_storage = NULL;
	struct bpf_local_storage *sk_storage;
	struct bpf_local_storage_elem *selem;
	int ret = 0;

	RCU_INIT_POINTER(newsk->sk_bpf_storage, NULL);

	rcu_read_lock();
	sk_storage = rcu_dereference(sk->sk_bpf_storage);

	if (!sk_storage || hlist_empty(&sk_storage->list))
		goto out;

	hlist_for_each_entry_rcu(selem, &sk_storage->list, snode) {
		struct bpf_local_storage_elem *copy_selem;
		struct bpf_local_storage_map *smap;
		struct bpf_map *map;

		smap = rcu_dereference(SDATA(selem)->smap);
		if (!(smap->map.map_flags & BPF_F_CLONE))
			continue;

		/* Note that for lockless listeners adding new element
		 * here can race with cleanup in bpf_local_storage_map_free.
		 * Try to grab map refcnt to make sure that it's still
		 * alive and prevent concurrent removal.
		 */
		map = bpf_map_inc_not_zero(&smap->map);
		if (IS_ERR(map))
			continue;

		copy_selem = bpf_sk_storage_clone_elem(newsk, smap, selem);
		if (!copy_selem) {
			ret = -ENOMEM;
			bpf_map_put(map);
			goto out;
		}

		if (new_sk_storage) {
			bpf_selem_link_map(smap, copy_selem);
			bpf_selem_link_storage_nolock(new_sk_storage, copy_selem);
		} else {
			ret = bpf_local_storage_alloc(newsk, smap, copy_selem, GFP_ATOMIC);
			if (ret) {
				kfree(copy_selem);
				atomic_sub(smap->elem_size,
					   &newsk->sk_omem_alloc);
				bpf_map_put(map);
				goto out;
			}

			new_sk_storage =
				rcu_dereference(copy_selem->local_storage);
		}
		bpf_map_put(map);
	}

out:
	rcu_read_unlock();

	/* In case of an error, don't free anything explicitly here, the
	 * caller is responsible to call bpf_sk_storage_free.
	 */

	return ret;
}

/* *gfp_flags* is a hidden argument provided by the verifier */
BPF_CALL_5(bpf_sk_storage_get, struct bpf_map *, map, struct sock *, sk,
	   void *, value, u64, flags, gfp_t, gfp_flags)
{
	struct bpf_local_storage_data *sdata;

	WARN_ON_ONCE(!bpf_rcu_lock_held());
	if (!sk || !sk_fullsock(sk) || flags > BPF_SK_STORAGE_GET_F_CREATE)
		return (unsigned long)NULL;

	sdata = bpf_sk_storage_lookup(sk, map, true);
	if (sdata)
		return (unsigned long)sdata->data;

	if (flags == BPF_SK_STORAGE_GET_F_CREATE &&
	    /* Cannot add new elem to a going away sk.
	     * Otherwise, the new elem may become a leak
	     * (and also other memory issues during map
	     *  destruction).
	     */
	    refcount_inc_not_zero(&sk->sk_refcnt)) {
		sdata = bpf_local_storage_update(
			sk, (struct bpf_local_storage_map *)map, value,
			BPF_NOEXIST, gfp_flags);
		/* sk must be a fullsock (guaranteed by verifier),
		 * so sock_gen_put() is unnecessary.
		 */
		sock_put(sk);
		return IS_ERR(sdata) ?
			(unsigned long)NULL : (unsigned long)sdata->data;
	}

	return (unsigned long)NULL;
}

BPF_CALL_2(bpf_sk_storage_delete, struct bpf_map *, map, struct sock *, sk)
{
	WARN_ON_ONCE(!bpf_rcu_lock_held());
	if (!sk || !sk_fullsock(sk))
		return -EINVAL;

	if (refcount_inc_not_zero(&sk->sk_refcnt)) {
		int err;

		err = bpf_sk_storage_del(sk, map);
		sock_put(sk);
		return err;
	}

	return -ENOENT;
}

static int bpf_sk_storage_charge(struct bpf_local_storage_map *smap,
				 void *owner, u32 size)
{
	struct sock *sk = (struct sock *)owner;

	/* same check as in sock_kmalloc() */
	if (size <= sysctl_optmem_max &&
	    atomic_read(&sk->sk_omem_alloc) + size < sysctl_optmem_max) {
		atomic_add(size, &sk->sk_omem_alloc);
		return 0;
	}

	return -ENOMEM;
}

static void bpf_sk_storage_uncharge(struct bpf_local_storage_map *smap,
				    void *owner, u32 size)
{
	struct sock *sk = owner;

	atomic_sub(size, &sk->sk_omem_alloc);
}

static struct bpf_local_storage __rcu **
bpf_sk_storage_ptr(void *owner)
{
	struct sock *sk = owner;

	return &sk->sk_bpf_storage;
}

BTF_ID_LIST_SINGLE(sk_storage_map_btf_ids, struct, bpf_local_storage_map)
const struct bpf_map_ops sk_storage_map_ops = {
	.map_meta_equal = bpf_map_meta_equal,
	.map_alloc_check = bpf_local_storage_map_alloc_check,
	.map_alloc = bpf_sk_storage_map_alloc,
	.map_free = bpf_sk_storage_map_free,
	.map_get_next_key = notsupp_get_next_key,
	.map_lookup_elem = bpf_fd_sk_storage_lookup_elem,
	.map_update_elem = bpf_fd_sk_storage_update_elem,
	.map_delete_elem = bpf_fd_sk_storage_delete_elem,
	.map_check_btf = bpf_local_storage_map_check_btf,
	.map_btf_id = &sk_storage_map_btf_ids[0],
	.map_local_storage_charge = bpf_sk_storage_charge,
	.map_local_storage_uncharge = bpf_sk_storage_uncharge,
	.map_owner_storage_ptr = bpf_sk_storage_ptr,
};

const struct bpf_func_proto bpf_sk_storage_get_proto = {
	.func		= bpf_sk_storage_get,
	.gpl_only	= false,
	.ret_type	= RET_PTR_TO_MAP_VALUE_OR_NULL,
	.arg1_type	= ARG_CONST_MAP_PTR,
	.arg2_type	= ARG_PTR_TO_BTF_ID_SOCK_COMMON,
	.arg3_type	= ARG_PTR_TO_MAP_VALUE_OR_NULL,
	.arg4_type	= ARG_ANYTHING,
};

const struct bpf_func_proto bpf_sk_storage_get_cg_sock_proto = {
	.func		= bpf_sk_storage_get,
	.gpl_only	= false,
	.ret_type	= RET_PTR_TO_MAP_VALUE_OR_NULL,
	.arg1_type	= ARG_CONST_MAP_PTR,
	.arg2_type	= ARG_PTR_TO_CTX, /* context is 'struct sock' */
	.arg3_type	= ARG_PTR_TO_MAP_VALUE_OR_NULL,
	.arg4_type	= ARG_ANYTHING,
};

const struct bpf_func_proto bpf_sk_storage_delete_proto = {
	.func		= bpf_sk_storage_delete,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_CONST_MAP_PTR,
	.arg2_type	= ARG_PTR_TO_BTF_ID_SOCK_COMMON,
};

static bool bpf_sk_storage_tracing_allowed(const struct bpf_prog *prog)
{
	const struct btf *btf_vmlinux;
	const struct btf_type *t;
	const char *tname;
	u32 btf_id;

	if (prog->aux->dst_prog)
		return false;

	/* Ensure the tracing program is not tracing
	 * any bpf_sk_storage*() function while also
	 * using the bpf_sk_storage_(get|delete) helpers.
	 */
	switch (prog->expected_attach_type) {
	case BPF_TRACE_ITER:
	case BPF_TRACE_RAW_TP:
		/* bpf_sk_storage has no trace point */
		return true;
	case BPF_TRACE_FENTRY:
	case BPF_TRACE_FEXIT:
		btf_vmlinux = bpf_get_btf_vmlinux();
		if (IS_ERR_OR_NULL(btf_vmlinux))
			return false;
		btf_id = prog->aux->attach_btf_id;
		t = btf_type_by_id(btf_vmlinux, btf_id);
		tname = btf_name_by_offset(btf_vmlinux, t->name_off);
		return !!strncmp(tname, "bpf_sk_storage",
				 strlen("bpf_sk_storage"));
	default:
		return false;
	}

	return false;
}

/* *gfp_flags* is a hidden argument provided by the verifier */
BPF_CALL_5(bpf_sk_storage_get_tracing, struct bpf_map *, map, struct sock *, sk,
	   void *, value, u64, flags, gfp_t, gfp_flags)
{
	WARN_ON_ONCE(!bpf_rcu_lock_held());
	if (in_hardirq() || in_nmi())
		return (unsigned long)NULL;

	return (unsigned long)____bpf_sk_storage_get(map, sk, value, flags,
						     gfp_flags);
}

BPF_CALL_2(bpf_sk_storage_delete_tracing, struct bpf_map *, map,
	   struct sock *, sk)
{
	WARN_ON_ONCE(!bpf_rcu_lock_held());
	if (in_hardirq() || in_nmi())
		return -EPERM;

	return ____bpf_sk_storage_delete(map, sk);
}

const struct bpf_func_proto bpf_sk_storage_get_tracing_proto = {
	.func		= bpf_sk_storage_get_tracing,
	.gpl_only	= false,
	.ret_type	= RET_PTR_TO_MAP_VALUE_OR_NULL,
	.arg1_type	= ARG_CONST_MAP_PTR,
	.arg2_type	= ARG_PTR_TO_BTF_ID,
	.arg2_btf_id	= &btf_sock_ids[BTF_SOCK_TYPE_SOCK_COMMON],
	.arg3_type	= ARG_PTR_TO_MAP_VALUE_OR_NULL,
	.arg4_type	= ARG_ANYTHING,
	.allowed	= bpf_sk_storage_tracing_allowed,
};

const struct bpf_func_proto bpf_sk_storage_delete_tracing_proto = {
	.func		= bpf_sk_storage_delete_tracing,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_CONST_MAP_PTR,
	.arg2_type	= ARG_PTR_TO_BTF_ID,
	.arg2_btf_id	= &btf_sock_ids[BTF_SOCK_TYPE_SOCK_COMMON],
	.allowed	= bpf_sk_storage_tracing_allowed,
};

struct bpf_sk_storage_diag {
	u32 nr_maps;
	struct bpf_map *maps[];
};

/* The reply will be like:
 * INET_DIAG_BPF_SK_STORAGES (nla_nest)
 *	SK_DIAG_BPF_STORAGE (nla_nest)
 *		SK_DIAG_BPF_STORAGE_MAP_ID (nla_put_u32)
 *		SK_DIAG_BPF_STORAGE_MAP_VALUE (nla_reserve_64bit)
 *	SK_DIAG_BPF_STORAGE (nla_nest)
 *		SK_DIAG_BPF_STORAGE_MAP_ID (nla_put_u32)
 *		SK_DIAG_BPF_STORAGE_MAP_VALUE (nla_reserve_64bit)
 *	....
 */
static int nla_value_size(u32 value_size)
{
	/* SK_DIAG_BPF_STORAGE (nla_nest)
	 *	SK_DIAG_BPF_STORAGE_MAP_ID (nla_put_u32)
	 *	SK_DIAG_BPF_STORAGE_MAP_VALUE (nla_reserve_64bit)
	 */
	return nla_total_size(0) + nla_total_size(sizeof(u32)) +
		nla_total_size_64bit(value_size);
}

void bpf_sk_storage_diag_free(struct bpf_sk_storage_diag *diag)
{
	u32 i;

	if (!diag)
		return;

	for (i = 0; i < diag->nr_maps; i++)
		bpf_map_put(diag->maps[i]);

	kfree(diag);
}
EXPORT_SYMBOL_GPL(bpf_sk_storage_diag_free);

static bool diag_check_dup(const struct bpf_sk_storage_diag *diag,
			   const struct bpf_map *map)
{
	u32 i;

	for (i = 0; i < diag->nr_maps; i++) {
		if (diag->maps[i] == map)
			return true;
	}

	return false;
}

struct bpf_sk_storage_diag *
bpf_sk_storage_diag_alloc(const struct nlattr *nla_stgs)
{
	struct bpf_sk_storage_diag *diag;
	struct nlattr *nla;
	u32 nr_maps = 0;
	int rem, err;

	/* bpf_local_storage_map is currently limited to CAP_SYS_ADMIN,
	 * matching the check done on the map_alloc_check() side.
	 */
	if (!bpf_capable())
		return ERR_PTR(-EPERM);

	nla_for_each_nested(nla, nla_stgs, rem) {
		if (nla_type(nla) == SK_DIAG_BPF_STORAGE_REQ_MAP_FD)
			nr_maps++;
	}

	diag = kzalloc(struct_size(diag, maps, nr_maps), GFP_KERNEL);
	if (!diag)
		return ERR_PTR(-ENOMEM);

	nla_for_each_nested(nla, nla_stgs, rem) {
		struct bpf_map *map;
		int map_fd;

		if (nla_type(nla) != SK_DIAG_BPF_STORAGE_REQ_MAP_FD)
			continue;

		map_fd = nla_get_u32(nla);
		map = bpf_map_get(map_fd);
		if (IS_ERR(map)) {
			err = PTR_ERR(map);
			goto err_free;
		}
		if (map->map_type != BPF_MAP_TYPE_SK_STORAGE) {
			bpf_map_put(map);
			err = -EINVAL;
			goto err_free;
		}
		if (diag_check_dup(diag, map)) {
			bpf_map_put(map);
			err = -EEXIST;
			goto err_free;
		}
		diag->maps[diag->nr_maps++] = map;
	}

	return diag;

err_free:
	bpf_sk_storage_diag_free(diag);
	return ERR_PTR(err);
}
EXPORT_SYMBOL_GPL(bpf_sk_storage_diag_alloc);

static int diag_get(struct bpf_local_storage_data *sdata, struct sk_buff *skb)
{
	struct nlattr *nla_stg, *nla_value;
	struct bpf_local_storage_map *smap;

	/* It cannot exceed max nlattr's payload */
	BUILD_BUG_ON(U16_MAX - NLA_HDRLEN < BPF_LOCAL_STORAGE_MAX_VALUE_SIZE);

	nla_stg = nla_nest_start(skb, SK_DIAG_BPF_STORAGE);
	if (!nla_stg)
		return -EMSGSIZE;

	smap = rcu_dereference(sdata->smap);
	if (nla_put_u32(skb, SK_DIAG_BPF_STORAGE_MAP_ID, smap->map.id))
		goto errout;

	nla_value = nla_reserve_64bit(skb, SK_DIAG_BPF_STORAGE_MAP_VALUE,
				      smap->map.value_size,
				      SK_DIAG_BPF_STORAGE_PAD);
	if (!nla_value)
		goto errout;

	if (map_value_has_spin_lock(&smap->map))
		copy_map_value_locked(&smap->map, nla_data(nla_value),
				      sdata->data, true);
	else
		copy_map_value(&smap->map, nla_data(nla_value), sdata->data);

	nla_nest_end(skb, nla_stg);
	return 0;

errout:
	nla_nest_cancel(skb, nla_stg);
	return -EMSGSIZE;
}

static int bpf_sk_storage_diag_put_all(struct sock *sk, struct sk_buff *skb,
				       int stg_array_type,
				       unsigned int *res_diag_size)
{
	/* stg_array_type (e.g. INET_DIAG_BPF_SK_STORAGES) */
	unsigned int diag_size = nla_total_size(0);
	struct bpf_local_storage *sk_storage;
	struct bpf_local_storage_elem *selem;
	struct bpf_local_storage_map *smap;
	struct nlattr *nla_stgs;
	unsigned int saved_len;
	int err = 0;

	rcu_read_lock();

	sk_storage = rcu_dereference(sk->sk_bpf_storage);
	if (!sk_storage || hlist_empty(&sk_storage->list)) {
		rcu_read_unlock();
		return 0;
	}

	nla_stgs = nla_nest_start(skb, stg_array_type);
	if (!nla_stgs)
		/* Continue to learn diag_size */
		err = -EMSGSIZE;

	saved_len = skb->len;
	hlist_for_each_entry_rcu(selem, &sk_storage->list, snode) {
		smap = rcu_dereference(SDATA(selem)->smap);
		diag_size += nla_value_size(smap->map.value_size);

		if (nla_stgs && diag_get(SDATA(selem), skb))
			/* Continue to learn diag_size */
			err = -EMSGSIZE;
	}

	rcu_read_unlock();

	if (nla_stgs) {
		if (saved_len == skb->len)
			nla_nest_cancel(skb, nla_stgs);
		else
			nla_nest_end(skb, nla_stgs);
	}

	if (diag_size == nla_total_size(0)) {
		*res_diag_size = 0;
		return 0;
	}

	*res_diag_size = diag_size;
	return err;
}

int bpf_sk_storage_diag_put(struct bpf_sk_storage_diag *diag,
			    struct sock *sk, struct sk_buff *skb,
			    int stg_array_type,
			    unsigned int *res_diag_size)
{
	/* stg_array_type (e.g. INET_DIAG_BPF_SK_STORAGES) */
	unsigned int diag_size = nla_total_size(0);
	struct bpf_local_storage *sk_storage;
	struct bpf_local_storage_data *sdata;
	struct nlattr *nla_stgs;
	unsigned int saved_len;
	int err = 0;
	u32 i;

	*res_diag_size = 0;

	/* No map has been specified.  Dump all. */
	if (!diag->nr_maps)
		return bpf_sk_storage_diag_put_all(sk, skb, stg_array_type,
						   res_diag_size);

	rcu_read_lock();
	sk_storage = rcu_dereference(sk->sk_bpf_storage);
	if (!sk_storage || hlist_empty(&sk_storage->list)) {
		rcu_read_unlock();
		return 0;
	}

	nla_stgs = nla_nest_start(skb, stg_array_type);
	if (!nla_stgs)
		/* Continue to learn diag_size */
		err = -EMSGSIZE;

	saved_len = skb->len;
	for (i = 0; i < diag->nr_maps; i++) {
		sdata = bpf_local_storage_lookup(sk_storage,
				(struct bpf_local_storage_map *)diag->maps[i],
				false);

		if (!sdata)
			continue;

		diag_size += nla_value_size(diag->maps[i]->value_size);

		if (nla_stgs && diag_get(sdata, skb))
			/* Continue to learn diag_size */
			err = -EMSGSIZE;
	}
	rcu_read_unlock();

	if (nla_stgs) {
		if (saved_len == skb->len)
			nla_nest_cancel(skb, nla_stgs);
		else
			nla_nest_end(skb, nla_stgs);
	}

	if (diag_size == nla_total_size(0)) {
		*res_diag_size = 0;
		return 0;
	}

	*res_diag_size = diag_size;
	return err;
}
EXPORT_SYMBOL_GPL(bpf_sk_storage_diag_put);

struct bpf_iter_seq_sk_storage_map_info {
	struct bpf_map *map;
	unsigned int bucket_id;
	unsigned skip_elems;
};

static struct bpf_local_storage_elem *
bpf_sk_storage_map_seq_find_next(struct bpf_iter_seq_sk_storage_map_info *info,
				 struct bpf_local_storage_elem *prev_selem)
	__acquires(RCU) __releases(RCU)
{
	struct bpf_local_storage *sk_storage;
	struct bpf_local_storage_elem *selem;
	u32 skip_elems = info->skip_elems;
	struct bpf_local_storage_map *smap;
	u32 bucket_id = info->bucket_id;
	u32 i, count, n_buckets;
	struct bpf_local_storage_map_bucket *b;

	smap = (struct bpf_local_storage_map *)info->map;
	n_buckets = 1U << smap->bucket_log;
	if (bucket_id >= n_buckets)
		return NULL;

	/* try to find next selem in the same bucket */
	selem = prev_selem;
	count = 0;
	while (selem) {
		selem = hlist_entry_safe(rcu_dereference(hlist_next_rcu(&selem->map_node)),
					 struct bpf_local_storage_elem, map_node);
		if (!selem) {
			/* not found, unlock and go to the next bucket */
			b = &smap->buckets[bucket_id++];
			rcu_read_unlock();
			skip_elems = 0;
			break;
		}
		sk_storage = rcu_dereference(selem->local_storage);
		if (sk_storage) {
			info->skip_elems = skip_elems + count;
			return selem;
		}
		count++;
	}

	for (i = bucket_id; i < (1U << smap->bucket_log); i++) {
		b = &smap->buckets[i];
		rcu_read_lock();
		count = 0;
		hlist_for_each_entry_rcu(selem, &b->list, map_node) {
			sk_storage = rcu_dereference(selem->local_storage);
			if (sk_storage && count >= skip_elems) {
				info->bucket_id = i;
				info->skip_elems = count;
				return selem;
			}
			count++;
		}
		rcu_read_unlock();
		skip_elems = 0;
	}

	info->bucket_id = i;
	info->skip_elems = 0;
	return NULL;
}

static void *bpf_sk_storage_map_seq_start(struct seq_file *seq, loff_t *pos)
{
	struct bpf_local_storage_elem *selem;

	selem = bpf_sk_storage_map_seq_find_next(seq->private, NULL);
	if (!selem)
		return NULL;

	if (*pos == 0)
		++*pos;
	return selem;
}

static void *bpf_sk_storage_map_seq_next(struct seq_file *seq, void *v,
					 loff_t *pos)
{
	struct bpf_iter_seq_sk_storage_map_info *info = seq->private;

	++*pos;
	++info->skip_elems;
	return bpf_sk_storage_map_seq_find_next(seq->private, v);
}

struct bpf_iter__bpf_sk_storage_map {
	__bpf_md_ptr(struct bpf_iter_meta *, meta);
	__bpf_md_ptr(struct bpf_map *, map);
	__bpf_md_ptr(struct sock *, sk);
	__bpf_md_ptr(void *, value);
};

DEFINE_BPF_ITER_FUNC(bpf_sk_storage_map, struct bpf_iter_meta *meta,
		     struct bpf_map *map, struct sock *sk,
		     void *value)

static int __bpf_sk_storage_map_seq_show(struct seq_file *seq,
					 struct bpf_local_storage_elem *selem)
{
	struct bpf_iter_seq_sk_storage_map_info *info = seq->private;
	struct bpf_iter__bpf_sk_storage_map ctx = {};
	struct bpf_local_storage *sk_storage;
	struct bpf_iter_meta meta;
	struct bpf_prog *prog;
	int ret = 0;

	meta.seq = seq;
	prog = bpf_iter_get_info(&meta, selem == NULL);
	if (prog) {
		ctx.meta = &meta;
		ctx.map = info->map;
		if (selem) {
			sk_storage = rcu_dereference(selem->local_storage);
			ctx.sk = sk_storage->owner;
			ctx.value = SDATA(selem)->data;
		}
		ret = bpf_iter_run_prog(prog, &ctx);
	}

	return ret;
}

static int bpf_sk_storage_map_seq_show(struct seq_file *seq, void *v)
{
	return __bpf_sk_storage_map_seq_show(seq, v);
}

static void bpf_sk_storage_map_seq_stop(struct seq_file *seq, void *v)
	__releases(RCU)
{
	if (!v)
		(void)__bpf_sk_storage_map_seq_show(seq, v);
	else
		rcu_read_unlock();
}

static int bpf_iter_init_sk_storage_map(void *priv_data,
					struct bpf_iter_aux_info *aux)
{
	struct bpf_iter_seq_sk_storage_map_info *seq_info = priv_data;

	seq_info->map = aux->map;
	return 0;
}

static int bpf_iter_attach_map(struct bpf_prog *prog,
			       union bpf_iter_link_info *linfo,
			       struct bpf_iter_aux_info *aux)
{
	struct bpf_map *map;
	int err = -EINVAL;

	if (!linfo->map.map_fd)
		return -EBADF;

	map = bpf_map_get_with_uref(linfo->map.map_fd);
	if (IS_ERR(map))
		return PTR_ERR(map);

	if (map->map_type != BPF_MAP_TYPE_SK_STORAGE)
		goto put_map;

	if (prog->aux->max_rdonly_access > map->value_size) {
		err = -EACCES;
		goto put_map;
	}

	aux->map = map;
	return 0;

put_map:
	bpf_map_put_with_uref(map);
	return err;
}

static void bpf_iter_detach_map(struct bpf_iter_aux_info *aux)
{
	bpf_map_put_with_uref(aux->map);
}

static const struct seq_operations bpf_sk_storage_map_seq_ops = {
	.start  = bpf_sk_storage_map_seq_start,
	.next   = bpf_sk_storage_map_seq_next,
	.stop   = bpf_sk_storage_map_seq_stop,
	.show   = bpf_sk_storage_map_seq_show,
};

static const struct bpf_iter_seq_info iter_seq_info = {
	.seq_ops		= &bpf_sk_storage_map_seq_ops,
	.init_seq_private	= bpf_iter_init_sk_storage_map,
	.fini_seq_private	= NULL,
	.seq_priv_size		= sizeof(struct bpf_iter_seq_sk_storage_map_info),
};

static struct bpf_iter_reg bpf_sk_storage_map_reg_info = {
	.target			= "bpf_sk_storage_map",
	.attach_target		= bpf_iter_attach_map,
	.detach_target		= bpf_iter_detach_map,
	.show_fdinfo		= bpf_iter_map_show_fdinfo,
	.fill_link_info		= bpf_iter_map_fill_link_info,
	.ctx_arg_info_size	= 2,
	.ctx_arg_info		= {
		{ offsetof(struct bpf_iter__bpf_sk_storage_map, sk),
		  PTR_TO_BTF_ID_OR_NULL },
		{ offsetof(struct bpf_iter__bpf_sk_storage_map, value),
		  PTR_TO_BUF | PTR_MAYBE_NULL },
	},
	.seq_info		= &iter_seq_info,
};

static int __init bpf_sk_storage_map_iter_init(void)
{
	bpf_sk_storage_map_reg_info.ctx_arg_info[0].btf_id =
		btf_sock_ids[BTF_SOCK_TYPE_SOCK];
	return bpf_iter_reg_target(&bpf_sk_storage_map_reg_info);
}
late_initcall(bpf_sk_storage_map_iter_init);
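For context, below is a minimal BPF-side sketch (not part of this file) of how the helpers implemented above are consumed from a BPF program: a BPF_MAP_TYPE_SK_STORAGE map plus bpf_sk_storage_get() with BPF_SK_STORAGE_GET_F_CREATE. It assumes libbpf conventions (vmlinux.h, SEC() macros); the map name sk_pkt_cnt and the sockops attach point are illustrative choices, not something defined by bpf_sk_storage.c.

// SPDX-License-Identifier: GPL-2.0
/* Illustrative sketch only: per-socket counter kept in sk storage.
 * Hypothetical names; assumes a libbpf-style build with vmlinux.h.
 */
#include "vmlinux.h"
#include <bpf/bpf_helpers.h>

struct {
	__uint(type, BPF_MAP_TYPE_SK_STORAGE);
	__uint(map_flags, BPF_F_NO_PREALLOC);
	__type(key, int);
	__type(value, __u64);
} sk_pkt_cnt SEC(".maps");

SEC("sockops")
int count_sockops_events(struct bpf_sock_ops *ctx)
{
	struct bpf_sock *sk = ctx->sk;
	__u64 *cnt;

	if (!sk)
		return 1;

	/* Creates the per-socket value on first use; later calls return
	 * the already-stored value for this socket.
	 */
	cnt = bpf_sk_storage_get(&sk_pkt_cnt, sk, 0,
				 BPF_SK_STORAGE_GET_F_CREATE);
	if (cnt)
		__sync_fetch_and_add(cnt, 1);

	return 1;
}

char _license[] SEC("license") = "GPL";

The value lives as long as the socket does (it is freed from __sk_destruct() via bpf_sk_storage_free() above), and bpf_sk_storage_delete() can drop it early from the same program.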