cachepc-linux

Fork of AMDESE/linux with modifications for the CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux
Log | Files | Refs | README | LICENSE | sfeed.txt

mptcp.c (3718B)


      1// SPDX-License-Identifier: GPL-2.0
      2/* Copyright (c) 2020, Tessares SA. */
      3/* Copyright (c) 2022, SUSE. */
      4
      5#include <test_progs.h>
      6#include "cgroup_helpers.h"
      7#include "network_helpers.h"
      8#include "mptcp_sock.skel.h"
      9
     10#ifndef TCP_CA_NAME_MAX
     11#define TCP_CA_NAME_MAX	16
     12#endif
     13
/*
 * Userspace view of one socket_storage_map value, read back with
 * bpf_map_lookup_elem() using the client socket fd as the key. The values
 * are presumably written by the sockops program in mptcp_sock.skel.h.
 * NOTE(review): field order and types presumably must match the BPF side's
 * definition exactly - confirm against the BPF program source.
 */
struct mptcp_storage {
	__u32 invoked;		/* both verify helpers expect exactly 1 */
	__u32 is_mptcp;		/* expected 0 for plain TCP, 1 for MPTCP */
	struct sock *sk;	/* never dereferenced here; only compared with first */
	__u32 token;		/* compared with the token from the skeleton's .bss */
	struct sock *first;	/* verify_msk() expects this to equal sk */
	char ca_name[TCP_CA_NAME_MAX];	/* compared with the tcp_congestion_control sysctl */
};
     22
     23static int verify_tsk(int map_fd, int client_fd)
     24{
     25	int err, cfd = client_fd;
     26	struct mptcp_storage val;
     27
     28	err = bpf_map_lookup_elem(map_fd, &cfd, &val);
     29	if (!ASSERT_OK(err, "bpf_map_lookup_elem"))
     30		return err;
     31
     32	if (!ASSERT_EQ(val.invoked, 1, "unexpected invoked count"))
     33		err++;
     34
     35	if (!ASSERT_EQ(val.is_mptcp, 0, "unexpected is_mptcp"))
     36		err++;
     37
     38	return err;
     39}
     40
     41static void get_msk_ca_name(char ca_name[])
     42{
     43	size_t len;
     44	int fd;
     45
     46	fd = open("/proc/sys/net/ipv4/tcp_congestion_control", O_RDONLY);
     47	if (!ASSERT_GE(fd, 0, "failed to open tcp_congestion_control"))
     48		return;
     49
     50	len = read(fd, ca_name, TCP_CA_NAME_MAX);
     51	if (!ASSERT_GT(len, 0, "failed to read ca_name"))
     52		goto err;
     53
     54	if (len > 0 && ca_name[len - 1] == '\n')
     55		ca_name[len - 1] = '\0';
     56
     57err:
     58	close(fd);
     59}
     60
     61static int verify_msk(int map_fd, int client_fd, __u32 token)
     62{
     63	char ca_name[TCP_CA_NAME_MAX];
     64	int err, cfd = client_fd;
     65	struct mptcp_storage val;
     66
     67	if (!ASSERT_GT(token, 0, "invalid token"))
     68		return -1;
     69
     70	get_msk_ca_name(ca_name);
     71
     72	err = bpf_map_lookup_elem(map_fd, &cfd, &val);
     73	if (!ASSERT_OK(err, "bpf_map_lookup_elem"))
     74		return err;
     75
     76	if (!ASSERT_EQ(val.invoked, 1, "unexpected invoked count"))
     77		err++;
     78
     79	if (!ASSERT_EQ(val.is_mptcp, 1, "unexpected is_mptcp"))
     80		err++;
     81
     82	if (!ASSERT_EQ(val.token, token, "unexpected token"))
     83		err++;
     84
     85	if (!ASSERT_EQ(val.first, val.sk, "unexpected first"))
     86		err++;
     87
     88	if (!ASSERT_STRNEQ(val.ca_name, ca_name, TCP_CA_NAME_MAX, "unexpected ca_name"))
     89		err++;
     90
     91	return err;
     92}
     93
     94static int run_test(int cgroup_fd, int server_fd, bool is_mptcp)
     95{
     96	int client_fd, prog_fd, map_fd, err;
     97	struct mptcp_sock *sock_skel;
     98
     99	sock_skel = mptcp_sock__open_and_load();
    100	if (!ASSERT_OK_PTR(sock_skel, "skel_open_load"))
    101		return -EIO;
    102
    103	err = mptcp_sock__attach(sock_skel);
    104	if (!ASSERT_OK(err, "skel_attach"))
    105		goto out;
    106
    107	prog_fd = bpf_program__fd(sock_skel->progs._sockops);
    108	if (!ASSERT_GE(prog_fd, 0, "bpf_program__fd")) {
    109		err = -EIO;
    110		goto out;
    111	}
    112
    113	map_fd = bpf_map__fd(sock_skel->maps.socket_storage_map);
    114	if (!ASSERT_GE(map_fd, 0, "bpf_map__fd")) {
    115		err = -EIO;
    116		goto out;
    117	}
    118
    119	err = bpf_prog_attach(prog_fd, cgroup_fd, BPF_CGROUP_SOCK_OPS, 0);
    120	if (!ASSERT_OK(err, "bpf_prog_attach"))
    121		goto out;
    122
    123	client_fd = connect_to_fd(server_fd, 0);
    124	if (!ASSERT_GE(client_fd, 0, "connect to fd")) {
    125		err = -EIO;
    126		goto out;
    127	}
    128
    129	err += is_mptcp ? verify_msk(map_fd, client_fd, sock_skel->bss->token) :
    130			  verify_tsk(map_fd, client_fd);
    131
    132	close(client_fd);
    133
    134out:
    135	mptcp_sock__destroy(sock_skel);
    136	return err;
    137}
    138
    139static void test_base(void)
    140{
    141	int server_fd, cgroup_fd;
    142
    143	cgroup_fd = test__join_cgroup("/mptcp");
    144	if (!ASSERT_GE(cgroup_fd, 0, "test__join_cgroup"))
    145		return;
    146
    147	/* without MPTCP */
    148	server_fd = start_server(AF_INET, SOCK_STREAM, NULL, 0, 0);
    149	if (!ASSERT_GE(server_fd, 0, "start_server"))
    150		goto with_mptcp;
    151
    152	ASSERT_OK(run_test(cgroup_fd, server_fd, false), "run_test tcp");
    153
    154	close(server_fd);
    155
    156with_mptcp:
    157	/* with MPTCP */
    158	server_fd = start_mptcp_server(AF_INET, NULL, 0, 0);
    159	if (!ASSERT_GE(server_fd, 0, "start_mptcp_server"))
    160		goto close_cgroup_fd;
    161
    162	ASSERT_OK(run_test(cgroup_fd, server_fd, true), "run_test mptcp");
    163
    164	close(server_fd);
    165
    166close_cgroup_fd:
    167	close(cgroup_fd);
    168}
    169
    170void test_mptcp(void)
    171{
    172	if (test__start_subtest("base"))
    173		test_base();
    174}