cachepc-linux

Fork of AMDESE/linux with modifications for CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux
Log | Files | Refs | README | LICENSE | sfeed.txt

sockopt_inherit.c (5314B)


      1// SPDX-License-Identifier: GPL-2.0
      2#include <test_progs.h>
      3#include "cgroup_helpers.h"
      4
      5#define SOL_CUSTOM			0xdeadbeef
      6#define CUSTOM_INHERIT1			0
      7#define CUSTOM_INHERIT2			1
      8#define CUSTOM_LISTENER			2
      9
     10static int connect_to_server(int server_fd)
     11{
     12	struct sockaddr_storage addr;
     13	socklen_t len = sizeof(addr);
     14	int fd;
     15
     16	fd = socket(AF_INET, SOCK_STREAM, 0);
     17	if (fd < 0) {
     18		log_err("Failed to create client socket");
     19		return -1;
     20	}
     21
     22	if (getsockname(server_fd, (struct sockaddr *)&addr, &len)) {
     23		log_err("Failed to get server addr");
     24		goto out;
     25	}
     26
     27	if (connect(fd, (const struct sockaddr *)&addr, len) < 0) {
     28		log_err("Fail to connect to server");
     29		goto out;
     30	}
     31
     32	return fd;
     33
     34out:
     35	close(fd);
     36	return -1;
     37}
     38
     39static int verify_sockopt(int fd, int optname, const char *msg, char expected)
     40{
     41	socklen_t optlen = 1;
     42	char buf = 0;
     43	int err;
     44
     45	err = getsockopt(fd, SOL_CUSTOM, optname, &buf, &optlen);
     46	if (err) {
     47		log_err("%s: failed to call getsockopt", msg);
     48		return 1;
     49	}
     50
     51	printf("%s %d: got=0x%x ? expected=0x%x\n", msg, optname, buf, expected);
     52
     53	if (buf != expected) {
     54		log_err("%s: unexpected getsockopt value %d != %d", msg,
     55			buf, expected);
     56		return 1;
     57	}
     58
     59	return 0;
     60}
     61
     62static pthread_mutex_t server_started_mtx = PTHREAD_MUTEX_INITIALIZER;
     63static pthread_cond_t server_started = PTHREAD_COND_INITIALIZER;
     64
     65static void *server_thread(void *arg)
     66{
     67	struct sockaddr_storage addr;
     68	socklen_t len = sizeof(addr);
     69	int fd = *(int *)arg;
     70	int client_fd;
     71	int err = 0;
     72
     73	err = listen(fd, 1);
     74
     75	pthread_mutex_lock(&server_started_mtx);
     76	pthread_cond_signal(&server_started);
     77	pthread_mutex_unlock(&server_started_mtx);
     78
     79	if (CHECK_FAIL(err < 0)) {
     80		perror("Failed to listed on socket");
     81		return NULL;
     82	}
     83
     84	err += verify_sockopt(fd, CUSTOM_INHERIT1, "listen", 1);
     85	err += verify_sockopt(fd, CUSTOM_INHERIT2, "listen", 1);
     86	err += verify_sockopt(fd, CUSTOM_LISTENER, "listen", 1);
     87
     88	client_fd = accept(fd, (struct sockaddr *)&addr, &len);
     89	if (CHECK_FAIL(client_fd < 0)) {
     90		perror("Failed to accept client");
     91		return NULL;
     92	}
     93
     94	err += verify_sockopt(client_fd, CUSTOM_INHERIT1, "accept", 1);
     95	err += verify_sockopt(client_fd, CUSTOM_INHERIT2, "accept", 1);
     96	err += verify_sockopt(client_fd, CUSTOM_LISTENER, "accept", 0);
     97
     98	close(client_fd);
     99
    100	return (void *)(long)err;
    101}
    102
    103static int start_server(void)
    104{
    105	struct sockaddr_in addr = {
    106		.sin_family = AF_INET,
    107		.sin_addr.s_addr = htonl(INADDR_LOOPBACK),
    108	};
    109	char buf;
    110	int err;
    111	int fd;
    112	int i;
    113
    114	fd = socket(AF_INET, SOCK_STREAM, 0);
    115	if (fd < 0) {
    116		log_err("Failed to create server socket");
    117		return -1;
    118	}
    119
    120	for (i = CUSTOM_INHERIT1; i <= CUSTOM_LISTENER; i++) {
    121		buf = 0x01;
    122		err = setsockopt(fd, SOL_CUSTOM, i, &buf, 1);
    123		if (err) {
    124			log_err("Failed to call setsockopt(%d)", i);
    125			close(fd);
    126			return -1;
    127		}
    128	}
    129
    130	if (bind(fd, (const struct sockaddr *)&addr, sizeof(addr)) < 0) {
    131		log_err("Failed to bind socket");
    132		close(fd);
    133		return -1;
    134	}
    135
    136	return fd;
    137}
    138
    139static int prog_attach(struct bpf_object *obj, int cgroup_fd, const char *title,
    140		       const char *prog_name)
    141{
    142	enum bpf_attach_type attach_type;
    143	enum bpf_prog_type prog_type;
    144	struct bpf_program *prog;
    145	int err;
    146
    147	err = libbpf_prog_type_by_name(title, &prog_type, &attach_type);
    148	if (err) {
    149		log_err("Failed to deduct types for %s BPF program", prog_name);
    150		return -1;
    151	}
    152
    153	prog = bpf_object__find_program_by_name(obj, prog_name);
    154	if (!prog) {
    155		log_err("Failed to find %s BPF program", prog_name);
    156		return -1;
    157	}
    158
    159	err = bpf_prog_attach(bpf_program__fd(prog), cgroup_fd,
    160			      attach_type, 0);
    161	if (err) {
    162		log_err("Failed to attach %s BPF program", prog_name);
    163		return -1;
    164	}
    165
    166	return 0;
    167}
    168
    169static void run_test(int cgroup_fd)
    170{
    171	int server_fd = -1, client_fd;
    172	struct bpf_object *obj;
    173	void *server_err;
    174	pthread_t tid;
    175	int err;
    176
    177	obj = bpf_object__open_file("sockopt_inherit.o", NULL);
    178	if (!ASSERT_OK_PTR(obj, "obj_open"))
    179		return;
    180
    181	err = bpf_object__load(obj);
    182	if (!ASSERT_OK(err, "obj_load"))
    183		goto close_bpf_object;
    184
    185	err = prog_attach(obj, cgroup_fd, "cgroup/getsockopt", "_getsockopt");
    186	if (CHECK_FAIL(err))
    187		goto close_bpf_object;
    188
    189	err = prog_attach(obj, cgroup_fd, "cgroup/setsockopt", "_setsockopt");
    190	if (CHECK_FAIL(err))
    191		goto close_bpf_object;
    192
    193	server_fd = start_server();
    194	if (CHECK_FAIL(server_fd < 0))
    195		goto close_bpf_object;
    196
    197	pthread_mutex_lock(&server_started_mtx);
    198	if (CHECK_FAIL(pthread_create(&tid, NULL, server_thread,
    199				      (void *)&server_fd))) {
    200		pthread_mutex_unlock(&server_started_mtx);
    201		goto close_server_fd;
    202	}
    203	pthread_cond_wait(&server_started, &server_started_mtx);
    204	pthread_mutex_unlock(&server_started_mtx);
    205
    206	client_fd = connect_to_server(server_fd);
    207	if (CHECK_FAIL(client_fd < 0))
    208		goto close_server_fd;
    209
    210	CHECK_FAIL(verify_sockopt(client_fd, CUSTOM_INHERIT1, "connect", 0));
    211	CHECK_FAIL(verify_sockopt(client_fd, CUSTOM_INHERIT2, "connect", 0));
    212	CHECK_FAIL(verify_sockopt(client_fd, CUSTOM_LISTENER, "connect", 0));
    213
    214	pthread_join(tid, &server_err);
    215
    216	err = (int)(long)server_err;
    217	CHECK_FAIL(err);
    218
    219	close(client_fd);
    220
    221close_server_fd:
    222	close(server_fd);
    223close_bpf_object:
    224	bpf_object__close(obj);
    225}
    226
    227void test_sockopt_inherit(void)
    228{
    229	int cgroup_fd;
    230
    231	cgroup_fd = test__join_cgroup("/sockopt_inherit");
    232	if (CHECK_FAIL(cgroup_fd < 0))
    233		return;
    234
    235	run_test(cgroup_fd);
    236	close(cgroup_fd);
    237}