cachepc-linux

Fork of AMDESE/linux with modifications for the CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux

ratelimiter.c (5991B)


// SPDX-License-Identifier: GPL-2.0
/*
 * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
 */

#include "ratelimiter.h"
#include <linux/siphash.h>
#include <linux/mm.h>
#include <linux/slab.h>
#include <net/ip.h>

static struct kmem_cache *entry_cache;
static hsiphash_key_t key;
static spinlock_t table_lock = __SPIN_LOCK_UNLOCKED("ratelimiter_table_lock");
static DEFINE_MUTEX(init_lock);
static u64 init_refcnt; /* Protected by init_lock, hence not atomic. */
static atomic_t total_entries = ATOMIC_INIT(0);
static unsigned int max_entries, table_size;
static void wg_ratelimiter_gc_entries(struct work_struct *);
static DECLARE_DEFERRABLE_WORK(gc_work, wg_ratelimiter_gc_entries);
static struct hlist_head *table_v4;
#if IS_ENABLED(CONFIG_IPV6)
static struct hlist_head *table_v6;
#endif

struct ratelimiter_entry {
	u64 last_time_ns, tokens, ip;
	void *net;
	spinlock_t lock;
	struct hlist_node hash;
	struct rcu_head rcu;
};

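/* The token bucket below is denominated in nanoseconds: each packet costs
 * PACKET_COST = NSEC_PER_SEC / 20 = 50,000,000 ns of credit, and a bucket
 * holds at most TOKEN_MAX = 250,000,000 ns, i.e. a burst of
 * PACKETS_BURSTABLE (5) packets on top of the steady 20 packets/sec rate.
 */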
enum {
	PACKETS_PER_SECOND = 20,
	PACKETS_BURSTABLE = 5,
	PACKET_COST = NSEC_PER_SEC / PACKETS_PER_SECOND,
	TOKEN_MAX = PACKET_COST * PACKETS_BURSTABLE
};

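/* Entries are unlinked under table_lock and freed only after an RCU grace
 * period, so lockless readers in wg_ratelimiter_allow() never see a freed
 * entry.
 */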
static void entry_free(struct rcu_head *rcu)
{
	kmem_cache_free(entry_cache,
			container_of(rcu, struct ratelimiter_entry, rcu));
	atomic_dec(&total_entries);
}

static void entry_uninit(struct ratelimiter_entry *entry)
{
	hlist_del_rcu(&entry->hash);
	call_rcu(&entry->rcu, entry_free);
}

/* Calling this function with a NULL work uninits all entries. */
static void wg_ratelimiter_gc_entries(struct work_struct *work)
{
	const u64 now = ktime_get_coarse_boottime_ns();
	struct ratelimiter_entry *entry;
	struct hlist_node *temp;
	unsigned int i;

	for (i = 0; i < table_size; ++i) {
		spin_lock(&table_lock);
		hlist_for_each_entry_safe(entry, temp, &table_v4[i], hash) {
			if (unlikely(!work) ||
			    now - entry->last_time_ns > NSEC_PER_SEC)
				entry_uninit(entry);
		}
#if IS_ENABLED(CONFIG_IPV6)
		hlist_for_each_entry_safe(entry, temp, &table_v6[i], hash) {
			if (unlikely(!work) ||
			    now - entry->last_time_ns > NSEC_PER_SEC)
				entry_uninit(entry);
		}
#endif
		spin_unlock(&table_lock);
		if (likely(work))
			cond_resched();
	}
	if (likely(work))
		queue_delayed_work(system_power_efficient_wq, &gc_work, HZ);
}

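/* Returns true if the packet should be accepted. Looks up (net, source
 * address) under RCU; if an entry exists, its token bucket is refilled and
 * charged under the per-entry lock, otherwise a fresh entry is inserted
 * under table_lock. Packets that are neither IPv4 nor IPv6 are refused.
 */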
bool wg_ratelimiter_allow(struct sk_buff *skb, struct net *net)
{
	/* We only take the bottom half of the net pointer, so that we can hash
	 * 3 words in the end. This way, siphash's len param fits into the final
	 * u32, and we don't incur an extra round.
	 */
	const u32 net_word = (unsigned long)net;
	struct ratelimiter_entry *entry;
	struct hlist_head *bucket;
	u64 ip;

	if (skb->protocol == htons(ETH_P_IP)) {
		ip = (u64 __force)ip_hdr(skb)->saddr;
		bucket = &table_v4[hsiphash_2u32(net_word, ip, &key) &
				   (table_size - 1)];
	}
#if IS_ENABLED(CONFIG_IPV6)
	else if (skb->protocol == htons(ETH_P_IPV6)) {
		/* Only use 64 bits, so as to ratelimit the whole /64. */
		memcpy(&ip, &ipv6_hdr(skb)->saddr, sizeof(ip));
		bucket = &table_v6[hsiphash_3u32(net_word, ip >> 32, ip, &key) &
				   (table_size - 1)];
	}
#endif
	else
		return false;
	rcu_read_lock();
	hlist_for_each_entry_rcu(entry, bucket, hash) {
		if (entry->net == net && entry->ip == ip) {
			u64 now, tokens;
			bool ret;
			/* Quasi-inspired by nft_limit.c, but this is actually a
			 * slightly different algorithm. Namely, we incorporate
			 * the burst as part of the maximum tokens, rather than
			 * as part of the rate.
			 */
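			/* Both tokens and PACKET_COST are in nanoseconds, so the
			 * elapsed time since the last packet is added directly as
			 * credit, clamped to TOKEN_MAX.
			 */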
			spin_lock(&entry->lock);
			now = ktime_get_coarse_boottime_ns();
			tokens = min_t(u64, TOKEN_MAX,
				       entry->tokens + now -
					       entry->last_time_ns);
			entry->last_time_ns = now;
			ret = tokens >= PACKET_COST;
			entry->tokens = ret ? tokens - PACKET_COST : tokens;
			spin_unlock(&entry->lock);
			rcu_read_unlock();
			return ret;
		}
	}
	rcu_read_unlock();

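	/* No entry yet: try to insert one, unless the table already holds
	 * max_entries (8 * table_size) entries, in which case the packet is
	 * dropped instead of letting the table grow without bound.
	 */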
	if (atomic_inc_return(&total_entries) > max_entries)
		goto err_oom;

	entry = kmem_cache_alloc(entry_cache, GFP_KERNEL);
	if (unlikely(!entry))
		goto err_oom;

	entry->net = net;
	entry->ip = ip;
	INIT_HLIST_NODE(&entry->hash);
	spin_lock_init(&entry->lock);
	entry->last_time_ns = ktime_get_coarse_boottime_ns();
	entry->tokens = TOKEN_MAX - PACKET_COST;
	spin_lock(&table_lock);
	hlist_add_head_rcu(&entry->hash, bucket);
	spin_unlock(&table_lock);
	return true;

err_oom:
	atomic_dec(&total_entries);
	return false;
}

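/* Init/uninit are reference counted: only the first wg_ratelimiter_init()
 * call allocates the hash tables and schedules the garbage collector, and
 * only the matching last wg_ratelimiter_uninit() call tears them down.
 */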
int wg_ratelimiter_init(void)
{
	mutex_lock(&init_lock);
	if (++init_refcnt != 1)
		goto out;

	entry_cache = KMEM_CACHE(ratelimiter_entry, 0);
	if (!entry_cache)
		goto err;

	/* xt_hashlimit.c uses a slightly different algorithm for ratelimiting,
	 * but what it shares in common is that it uses a massive hashtable. So,
	 * we borrow their wisdom about good table sizes on different systems
	 * dependent on RAM. This calculation here comes from there.
	 */
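	/* E.g., with an 8-byte struct hlist_head on a 64-bit kernel, 512 MiB
	 * of RAM yields 2^29 / 2^14 / 8 = 4096 buckets; anything over 1 GiB
	 * is capped at 8192 buckets (and thus 65536 max entries).
	 */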
	table_size = (totalram_pages() > (1U << 30) / PAGE_SIZE) ? 8192 :
		max_t(unsigned long, 16, roundup_pow_of_two(
			(totalram_pages() << PAGE_SHIFT) /
			(1U << 14) / sizeof(struct hlist_head)));
	max_entries = table_size * 8;

	table_v4 = kvcalloc(table_size, sizeof(*table_v4), GFP_KERNEL);
	if (unlikely(!table_v4))
		goto err_kmemcache;

#if IS_ENABLED(CONFIG_IPV6)
	table_v6 = kvcalloc(table_size, sizeof(*table_v6), GFP_KERNEL);
	if (unlikely(!table_v6)) {
		kvfree(table_v4);
		goto err_kmemcache;
	}
#endif

	queue_delayed_work(system_power_efficient_wq, &gc_work, HZ);
	get_random_bytes(&key, sizeof(key));
out:
	mutex_unlock(&init_lock);
	return 0;

err_kmemcache:
	kmem_cache_destroy(entry_cache);
err:
	--init_refcnt;
	mutex_unlock(&init_lock);
	return -ENOMEM;
}

void wg_ratelimiter_uninit(void)
{
	mutex_lock(&init_lock);
	if (!init_refcnt || --init_refcnt)
		goto out;

	cancel_delayed_work_sync(&gc_work);
	wg_ratelimiter_gc_entries(NULL);
	rcu_barrier();
	kvfree(table_v4);
#if IS_ENABLED(CONFIG_IPV6)
	kvfree(table_v6);
#endif
	kmem_cache_destroy(entry_cache);
out:
	mutex_unlock(&init_lock);
}

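/* The ratelimiter selftest is compiled into this translation unit so it can
 * exercise the static functions and tables above.
 */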
#include "selftest/ratelimiter.c"