cachepc-linux

Fork of AMDESE/linux with modifications for CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux
Log | Files | Refs | README | LICENSE | sfeed.txt

comm.c (2888B)


      1// SPDX-License-Identifier: GPL-2.0
      2#include "comm.h"
      3#include <errno.h>
      4#include <stdlib.h>
      5#include <stdio.h>
      6#include <string.h>
      7#include <linux/refcount.h>
      8#include <linux/rbtree.h>
      9#include <linux/zalloc.h>
     10#include "rwsem.h"
     11
     12struct comm_str {
     13	char *str;
     14	struct rb_node rb_node;
     15	refcount_t refcnt;
     16};
     17
     18/* Should perhaps be moved to struct machine */
     19static struct rb_root comm_str_root;
     20static struct rw_semaphore comm_str_lock = {.lock = PTHREAD_RWLOCK_INITIALIZER,};
     21
     22static struct comm_str *comm_str__get(struct comm_str *cs)
     23{
     24	if (cs && refcount_inc_not_zero(&cs->refcnt))
     25		return cs;
     26
     27	return NULL;
     28}
     29
     30static void comm_str__put(struct comm_str *cs)
     31{
     32	if (cs && refcount_dec_and_test(&cs->refcnt)) {
     33		down_write(&comm_str_lock);
     34		rb_erase(&cs->rb_node, &comm_str_root);
     35		up_write(&comm_str_lock);
     36		zfree(&cs->str);
     37		free(cs);
     38	}
     39}
     40
     41static struct comm_str *comm_str__alloc(const char *str)
     42{
     43	struct comm_str *cs;
     44
     45	cs = zalloc(sizeof(*cs));
     46	if (!cs)
     47		return NULL;
     48
     49	cs->str = strdup(str);
     50	if (!cs->str) {
     51		free(cs);
     52		return NULL;
     53	}
     54
     55	refcount_set(&cs->refcnt, 1);
     56
     57	return cs;
     58}
     59
     60static
     61struct comm_str *__comm_str__findnew(const char *str, struct rb_root *root)
     62{
     63	struct rb_node **p = &root->rb_node;
     64	struct rb_node *parent = NULL;
     65	struct comm_str *iter, *new;
     66	int cmp;
     67
     68	while (*p != NULL) {
     69		parent = *p;
     70		iter = rb_entry(parent, struct comm_str, rb_node);
     71
     72		/*
     73		 * If we race with comm_str__put, iter->refcnt is 0
     74		 * and it will be removed within comm_str__put call
     75		 * shortly, ignore it in this search.
     76		 */
     77		cmp = strcmp(str, iter->str);
     78		if (!cmp && comm_str__get(iter))
     79			return iter;
     80
     81		if (cmp < 0)
     82			p = &(*p)->rb_left;
     83		else
     84			p = &(*p)->rb_right;
     85	}
     86
     87	new = comm_str__alloc(str);
     88	if (!new)
     89		return NULL;
     90
     91	rb_link_node(&new->rb_node, parent, p);
     92	rb_insert_color(&new->rb_node, root);
     93
     94	return new;
     95}
     96
     97static struct comm_str *comm_str__findnew(const char *str, struct rb_root *root)
     98{
     99	struct comm_str *cs;
    100
    101	down_write(&comm_str_lock);
    102	cs = __comm_str__findnew(str, root);
    103	up_write(&comm_str_lock);
    104
    105	return cs;
    106}
    107
    108struct comm *comm__new(const char *str, u64 timestamp, bool exec)
    109{
    110	struct comm *comm = zalloc(sizeof(*comm));
    111
    112	if (!comm)
    113		return NULL;
    114
    115	comm->start = timestamp;
    116	comm->exec = exec;
    117
    118	comm->comm_str = comm_str__findnew(str, &comm_str_root);
    119	if (!comm->comm_str) {
    120		free(comm);
    121		return NULL;
    122	}
    123
    124	return comm;
    125}
    126
    127int comm__override(struct comm *comm, const char *str, u64 timestamp, bool exec)
    128{
    129	struct comm_str *new, *old = comm->comm_str;
    130
    131	new = comm_str__findnew(str, &comm_str_root);
    132	if (!new)
    133		return -ENOMEM;
    134
    135	comm_str__put(old);
    136	comm->comm_str = new;
    137	comm->start = timestamp;
    138	if (exec)
    139		comm->exec = true;
    140
    141	return 0;
    142}
    143
    144void comm__free(struct comm *comm)
    145{
    146	comm_str__put(comm->comm_str);
    147	free(comm);
    148}
    149
    150const char *comm__str(const struct comm *comm)
    151{
    152	return comm->comm_str->str;
    153}