cachepc-linux

Fork of AMDESE/linux with modifications for CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux
Log | Files | Refs | README | LICENSE | sfeed.txt

act_sample.c (9583B)


      1// SPDX-License-Identifier: GPL-2.0-only
      2/*
      3 * net/sched/act_sample.c - Packet sampling tc action
      4 * Copyright (c) 2017 Yotam Gigi <yotamg@mellanox.com>
      5 */
      6
      7#include <linux/types.h>
      8#include <linux/kernel.h>
      9#include <linux/string.h>
     10#include <linux/errno.h>
     11#include <linux/skbuff.h>
     12#include <linux/rtnetlink.h>
     13#include <linux/module.h>
     14#include <linux/init.h>
     15#include <linux/gfp.h>
     16#include <net/net_namespace.h>
     17#include <net/netlink.h>
     18#include <net/pkt_sched.h>
     19#include <linux/tc_act/tc_sample.h>
     20#include <net/tc_act/tc_sample.h>
     21#include <net/psample.h>
     22#include <net/pkt_cls.h>
     23
     24#include <linux/if_arp.h>
     25
     26static unsigned int sample_net_id;
     27static struct tc_action_ops act_sample_ops;
     28
     29static const struct nla_policy sample_policy[TCA_SAMPLE_MAX + 1] = {
     30	[TCA_SAMPLE_PARMS]		= { .len = sizeof(struct tc_sample) },
     31	[TCA_SAMPLE_RATE]		= { .type = NLA_U32 },
     32	[TCA_SAMPLE_TRUNC_SIZE]		= { .type = NLA_U32 },
     33	[TCA_SAMPLE_PSAMPLE_GROUP]	= { .type = NLA_U32 },
     34};
     35
     36static int tcf_sample_init(struct net *net, struct nlattr *nla,
     37			   struct nlattr *est, struct tc_action **a,
     38			   struct tcf_proto *tp,
     39			   u32 flags, struct netlink_ext_ack *extack)
     40{
     41	struct tc_action_net *tn = net_generic(net, sample_net_id);
     42	bool bind = flags & TCA_ACT_FLAGS_BIND;
     43	struct nlattr *tb[TCA_SAMPLE_MAX + 1];
     44	struct psample_group *psample_group;
     45	u32 psample_group_num, rate, index;
     46	struct tcf_chain *goto_ch = NULL;
     47	struct tc_sample *parm;
     48	struct tcf_sample *s;
     49	bool exists = false;
     50	int ret, err;
     51
     52	if (!nla)
     53		return -EINVAL;
     54	ret = nla_parse_nested_deprecated(tb, TCA_SAMPLE_MAX, nla,
     55					  sample_policy, NULL);
     56	if (ret < 0)
     57		return ret;
     58	if (!tb[TCA_SAMPLE_PARMS] || !tb[TCA_SAMPLE_RATE] ||
     59	    !tb[TCA_SAMPLE_PSAMPLE_GROUP])
     60		return -EINVAL;
     61
     62	parm = nla_data(tb[TCA_SAMPLE_PARMS]);
     63	index = parm->index;
     64	err = tcf_idr_check_alloc(tn, &index, a, bind);
     65	if (err < 0)
     66		return err;
     67	exists = err;
     68	if (exists && bind)
     69		return 0;
     70
     71	if (!exists) {
     72		ret = tcf_idr_create(tn, index, est, a,
     73				     &act_sample_ops, bind, true, flags);
     74		if (ret) {
     75			tcf_idr_cleanup(tn, index);
     76			return ret;
     77		}
     78		ret = ACT_P_CREATED;
     79	} else if (!(flags & TCA_ACT_FLAGS_REPLACE)) {
     80		tcf_idr_release(*a, bind);
     81		return -EEXIST;
     82	}
     83	err = tcf_action_check_ctrlact(parm->action, tp, &goto_ch, extack);
     84	if (err < 0)
     85		goto release_idr;
     86
     87	rate = nla_get_u32(tb[TCA_SAMPLE_RATE]);
     88	if (!rate) {
     89		NL_SET_ERR_MSG(extack, "invalid sample rate");
     90		err = -EINVAL;
     91		goto put_chain;
     92	}
     93	psample_group_num = nla_get_u32(tb[TCA_SAMPLE_PSAMPLE_GROUP]);
     94	psample_group = psample_group_get(net, psample_group_num);
     95	if (!psample_group) {
     96		err = -ENOMEM;
     97		goto put_chain;
     98	}
     99
    100	s = to_sample(*a);
    101
    102	spin_lock_bh(&s->tcf_lock);
    103	goto_ch = tcf_action_set_ctrlact(*a, parm->action, goto_ch);
    104	s->rate = rate;
    105	s->psample_group_num = psample_group_num;
    106	psample_group = rcu_replace_pointer(s->psample_group, psample_group,
    107					    lockdep_is_held(&s->tcf_lock));
    108
    109	if (tb[TCA_SAMPLE_TRUNC_SIZE]) {
    110		s->truncate = true;
    111		s->trunc_size = nla_get_u32(tb[TCA_SAMPLE_TRUNC_SIZE]);
    112	}
    113	spin_unlock_bh(&s->tcf_lock);
    114
    115	if (psample_group)
    116		psample_group_put(psample_group);
    117	if (goto_ch)
    118		tcf_chain_put_by_act(goto_ch);
    119
    120	return ret;
    121put_chain:
    122	if (goto_ch)
    123		tcf_chain_put_by_act(goto_ch);
    124release_idr:
    125	tcf_idr_release(*a, bind);
    126	return err;
    127}
    128
    129static void tcf_sample_cleanup(struct tc_action *a)
    130{
    131	struct tcf_sample *s = to_sample(a);
    132	struct psample_group *psample_group;
    133
    134	/* last reference to action, no need to lock */
    135	psample_group = rcu_dereference_protected(s->psample_group, 1);
    136	RCU_INIT_POINTER(s->psample_group, NULL);
    137	if (psample_group)
    138		psample_group_put(psample_group);
    139}
    140
    141static bool tcf_sample_dev_ok_push(struct net_device *dev)
    142{
    143	switch (dev->type) {
    144	case ARPHRD_TUNNEL:
    145	case ARPHRD_TUNNEL6:
    146	case ARPHRD_SIT:
    147	case ARPHRD_IPGRE:
    148	case ARPHRD_IP6GRE:
    149	case ARPHRD_VOID:
    150	case ARPHRD_NONE:
    151		return false;
    152	default:
    153		return true;
    154	}
    155}
    156
    157static int tcf_sample_act(struct sk_buff *skb, const struct tc_action *a,
    158			  struct tcf_result *res)
    159{
    160	struct tcf_sample *s = to_sample(a);
    161	struct psample_group *psample_group;
    162	struct psample_metadata md = {};
    163	int retval;
    164
    165	tcf_lastuse_update(&s->tcf_tm);
    166	bstats_update(this_cpu_ptr(s->common.cpu_bstats), skb);
    167	retval = READ_ONCE(s->tcf_action);
    168
    169	psample_group = rcu_dereference_bh(s->psample_group);
    170
    171	/* randomly sample packets according to rate */
    172	if (psample_group && (prandom_u32() % s->rate == 0)) {
    173		if (!skb_at_tc_ingress(skb)) {
    174			md.in_ifindex = skb->skb_iif;
    175			md.out_ifindex = skb->dev->ifindex;
    176		} else {
    177			md.in_ifindex = skb->dev->ifindex;
    178		}
    179
    180		/* on ingress, the mac header gets popped, so push it back */
    181		if (skb_at_tc_ingress(skb) && tcf_sample_dev_ok_push(skb->dev))
    182			skb_push(skb, skb->mac_len);
    183
    184		md.trunc_size = s->truncate ? s->trunc_size : skb->len;
    185		psample_sample_packet(psample_group, skb, s->rate, &md);
    186
    187		if (skb_at_tc_ingress(skb) && tcf_sample_dev_ok_push(skb->dev))
    188			skb_pull(skb, skb->mac_len);
    189	}
    190
    191	return retval;
    192}
    193
    194static void tcf_sample_stats_update(struct tc_action *a, u64 bytes, u64 packets,
    195				    u64 drops, u64 lastuse, bool hw)
    196{
    197	struct tcf_sample *s = to_sample(a);
    198	struct tcf_t *tm = &s->tcf_tm;
    199
    200	tcf_action_update_stats(a, bytes, packets, drops, hw);
    201	tm->lastuse = max_t(u64, tm->lastuse, lastuse);
    202}
    203
    204static int tcf_sample_dump(struct sk_buff *skb, struct tc_action *a,
    205			   int bind, int ref)
    206{
    207	unsigned char *b = skb_tail_pointer(skb);
    208	struct tcf_sample *s = to_sample(a);
    209	struct tc_sample opt = {
    210		.index      = s->tcf_index,
    211		.refcnt     = refcount_read(&s->tcf_refcnt) - ref,
    212		.bindcnt    = atomic_read(&s->tcf_bindcnt) - bind,
    213	};
    214	struct tcf_t t;
    215
    216	spin_lock_bh(&s->tcf_lock);
    217	opt.action = s->tcf_action;
    218	if (nla_put(skb, TCA_SAMPLE_PARMS, sizeof(opt), &opt))
    219		goto nla_put_failure;
    220
    221	tcf_tm_dump(&t, &s->tcf_tm);
    222	if (nla_put_64bit(skb, TCA_SAMPLE_TM, sizeof(t), &t, TCA_SAMPLE_PAD))
    223		goto nla_put_failure;
    224
    225	if (nla_put_u32(skb, TCA_SAMPLE_RATE, s->rate))
    226		goto nla_put_failure;
    227
    228	if (s->truncate)
    229		if (nla_put_u32(skb, TCA_SAMPLE_TRUNC_SIZE, s->trunc_size))
    230			goto nla_put_failure;
    231
    232	if (nla_put_u32(skb, TCA_SAMPLE_PSAMPLE_GROUP, s->psample_group_num))
    233		goto nla_put_failure;
    234	spin_unlock_bh(&s->tcf_lock);
    235
    236	return skb->len;
    237
    238nla_put_failure:
    239	spin_unlock_bh(&s->tcf_lock);
    240	nlmsg_trim(skb, b);
    241	return -1;
    242}
    243
    244static int tcf_sample_walker(struct net *net, struct sk_buff *skb,
    245			     struct netlink_callback *cb, int type,
    246			     const struct tc_action_ops *ops,
    247			     struct netlink_ext_ack *extack)
    248{
    249	struct tc_action_net *tn = net_generic(net, sample_net_id);
    250
    251	return tcf_generic_walker(tn, skb, cb, type, ops, extack);
    252}
    253
    254static int tcf_sample_search(struct net *net, struct tc_action **a, u32 index)
    255{
    256	struct tc_action_net *tn = net_generic(net, sample_net_id);
    257
    258	return tcf_idr_search(tn, a, index);
    259}
    260
    261static void tcf_psample_group_put(void *priv)
    262{
    263	struct psample_group *group = priv;
    264
    265	psample_group_put(group);
    266}
    267
    268static struct psample_group *
    269tcf_sample_get_group(const struct tc_action *a,
    270		     tc_action_priv_destructor *destructor)
    271{
    272	struct tcf_sample *s = to_sample(a);
    273	struct psample_group *group;
    274
    275	group = rcu_dereference_protected(s->psample_group,
    276					  lockdep_is_held(&s->tcf_lock));
    277	if (group) {
    278		psample_group_take(group);
    279		*destructor = tcf_psample_group_put;
    280	}
    281
    282	return group;
    283}
    284
    285static void tcf_offload_sample_get_group(struct flow_action_entry *entry,
    286					 const struct tc_action *act)
    287{
    288	entry->sample.psample_group =
    289		act->ops->get_psample_group(act, &entry->destructor);
    290	entry->destructor_priv = entry->sample.psample_group;
    291}
    292
    293static int tcf_sample_offload_act_setup(struct tc_action *act, void *entry_data,
    294					u32 *index_inc, bool bind,
    295					struct netlink_ext_ack *extack)
    296{
    297	if (bind) {
    298		struct flow_action_entry *entry = entry_data;
    299
    300		entry->id = FLOW_ACTION_SAMPLE;
    301		entry->sample.trunc_size = tcf_sample_trunc_size(act);
    302		entry->sample.truncate = tcf_sample_truncate(act);
    303		entry->sample.rate = tcf_sample_rate(act);
    304		tcf_offload_sample_get_group(entry, act);
    305		*index_inc = 1;
    306	} else {
    307		struct flow_offload_action *fl_action = entry_data;
    308
    309		fl_action->id = FLOW_ACTION_SAMPLE;
    310	}
    311
    312	return 0;
    313}
    314
    315static struct tc_action_ops act_sample_ops = {
    316	.kind	  = "sample",
    317	.id	  = TCA_ID_SAMPLE,
    318	.owner	  = THIS_MODULE,
    319	.act	  = tcf_sample_act,
    320	.stats_update = tcf_sample_stats_update,
    321	.dump	  = tcf_sample_dump,
    322	.init	  = tcf_sample_init,
    323	.cleanup  = tcf_sample_cleanup,
    324	.walk	  = tcf_sample_walker,
    325	.lookup	  = tcf_sample_search,
    326	.get_psample_group = tcf_sample_get_group,
    327	.offload_act_setup    = tcf_sample_offload_act_setup,
    328	.size	  = sizeof(struct tcf_sample),
    329};
    330
    331static __net_init int sample_init_net(struct net *net)
    332{
    333	struct tc_action_net *tn = net_generic(net, sample_net_id);
    334
    335	return tc_action_net_init(net, tn, &act_sample_ops);
    336}
    337
    338static void __net_exit sample_exit_net(struct list_head *net_list)
    339{
    340	tc_action_net_exit(net_list, sample_net_id);
    341}
    342
    343static struct pernet_operations sample_net_ops = {
    344	.init = sample_init_net,
    345	.exit_batch = sample_exit_net,
    346	.id   = &sample_net_id,
    347	.size = sizeof(struct tc_action_net),
    348};
    349
    350static int __init sample_init_module(void)
    351{
    352	return tcf_register_action(&act_sample_ops, &sample_net_ops);
    353}
    354
    355static void __exit sample_cleanup_module(void)
    356{
    357	tcf_unregister_action(&act_sample_ops, &sample_net_ops);
    358}
    359
    360module_init(sample_init_module);
    361module_exit(sample_cleanup_module);
    362
    363MODULE_AUTHOR("Yotam Gigi <yotam.gi@gmail.com>");
    364MODULE_DESCRIPTION("Packet sampling action");
    365MODULE_LICENSE("GPL v2");