cachepc-linux

Fork of AMDESE/linux with modifications for the CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux
Log | Files | Refs | README | LICENSE | sfeed.txt

bpf_tcp_ca.c (8216B)


      1// SPDX-License-Identifier: GPL-2.0
      2/* Copyright (c) 2019 Facebook */
      3
      4#include <linux/err.h>
      5#include <netinet/tcp.h>
      6#include <test_progs.h>
      7#include "network_helpers.h"
      8#include "bpf_dctcp.skel.h"
      9#include "bpf_cubic.skel.h"
     10#include "bpf_tcp_nogpl.skel.h"
     11#include "bpf_dctcp_release.skel.h"
     12
     /*
 * Fallback definition for ENOTSUPP when the userspace errno headers do not
 * provide it. test_dctcp_fallback() compares tcp_cdg_res against -ENOTSUPP,
 * so 524 presumably has to match the kernel-internal value — confirm.
 */
#ifndef ENOTSUPP
#define ENOTSUPP 524
#endif

/* Number of bytes the server thread sends (and do_test() expects back). */
static const unsigned int total_bytes = 10 * 1024 * 1024;
/*
 * Value written into sk_stg_map for the client socket before connect() and
 * compared against the dctcp skeleton's bss->stg_result afterwards.
 * It is never modified, hence const.
 */
static const int expected_stg = 0xeB9F;
/* Cross-thread "give up" flag, accessed with READ_ONCE()/WRITE_ONCE(). */
static int stop;
/*
 * Not referenced directly in this file; presumably consumed by the CHECK()
 * macro from test_progs.h — confirm before removing.
 */
static int duration;
     20
     /*
 * Switch socket @fd to the congestion-control algorithm named @tcp_ca.
 *
 * A failure is reported through CHECK(). Returns 0 on success and -1 if
 * setsockopt(TCP_CONGESTION) failed.
 */
static int settcpca(int fd, const char *tcp_ca)
{
	int ret = setsockopt(fd, IPPROTO_TCP, TCP_CONGESTION, tcp_ca,
			     strlen(tcp_ca));

	return CHECK(ret == -1, "setsockopt(fd, TCP_CONGESTION)", "errno:%d\n",
		     errno) ? -1 : 0;
}
     32
     33static void *server(void *arg)
     34{
     35	int lfd = (int)(long)arg, err = 0, fd;
     36	ssize_t nr_sent = 0, bytes = 0;
     37	char batch[1500];
     38
     39	fd = accept(lfd, NULL, NULL);
     40	while (fd == -1) {
     41		if (errno == EINTR)
     42			continue;
     43		err = -errno;
     44		goto done;
     45	}
     46
     47	if (settimeo(fd, 0)) {
     48		err = -errno;
     49		goto done;
     50	}
     51
     52	while (bytes < total_bytes && !READ_ONCE(stop)) {
     53		nr_sent = send(fd, &batch,
     54			       MIN(total_bytes - bytes, sizeof(batch)), 0);
     55		if (nr_sent == -1 && errno == EINTR)
     56			continue;
     57		if (nr_sent == -1) {
     58			err = -errno;
     59			break;
     60		}
     61		bytes += nr_sent;
     62	}
     63
     64	CHECK(bytes != total_bytes, "send", "%zd != %u nr_sent:%zd errno:%d\n",
     65	      bytes, total_bytes, nr_sent, errno);
     66
     67done:
     68	if (fd >= 0)
     69		close(fd);
     70	if (err) {
     71		WRITE_ONCE(stop, 1);
     72		return ERR_PTR(err);
     73	}
     74	return NULL;
     75}
     76
     77static void do_test(const char *tcp_ca, const struct bpf_map *sk_stg_map)
     78{
     79	struct sockaddr_in6 sa6 = {};
     80	ssize_t nr_recv = 0, bytes = 0;
     81	int lfd = -1, fd = -1;
     82	pthread_t srv_thread;
     83	socklen_t addrlen = sizeof(sa6);
     84	void *thread_ret;
     85	char batch[1500];
     86	int err;
     87
     88	WRITE_ONCE(stop, 0);
     89
     90	lfd = socket(AF_INET6, SOCK_STREAM, 0);
     91	if (CHECK(lfd == -1, "socket", "errno:%d\n", errno))
     92		return;
     93	fd = socket(AF_INET6, SOCK_STREAM, 0);
     94	if (CHECK(fd == -1, "socket", "errno:%d\n", errno)) {
     95		close(lfd);
     96		return;
     97	}
     98
     99	if (settcpca(lfd, tcp_ca) || settcpca(fd, tcp_ca) ||
    100	    settimeo(lfd, 0) || settimeo(fd, 0))
    101		goto done;
    102
    103	/* bind, listen and start server thread to accept */
    104	sa6.sin6_family = AF_INET6;
    105	sa6.sin6_addr = in6addr_loopback;
    106	err = bind(lfd, (struct sockaddr *)&sa6, addrlen);
    107	if (CHECK(err == -1, "bind", "errno:%d\n", errno))
    108		goto done;
    109	err = getsockname(lfd, (struct sockaddr *)&sa6, &addrlen);
    110	if (CHECK(err == -1, "getsockname", "errno:%d\n", errno))
    111		goto done;
    112	err = listen(lfd, 1);
    113	if (CHECK(err == -1, "listen", "errno:%d\n", errno))
    114		goto done;
    115
    116	if (sk_stg_map) {
    117		err = bpf_map_update_elem(bpf_map__fd(sk_stg_map), &fd,
    118					  &expected_stg, BPF_NOEXIST);
    119		if (CHECK(err, "bpf_map_update_elem(sk_stg_map)",
    120			  "err:%d errno:%d\n", err, errno))
    121			goto done;
    122	}
    123
    124	/* connect to server */
    125	err = connect(fd, (struct sockaddr *)&sa6, addrlen);
    126	if (CHECK(err == -1, "connect", "errno:%d\n", errno))
    127		goto done;
    128
    129	if (sk_stg_map) {
    130		int tmp_stg;
    131
    132		err = bpf_map_lookup_elem(bpf_map__fd(sk_stg_map), &fd,
    133					  &tmp_stg);
    134		if (CHECK(!err || errno != ENOENT,
    135			  "bpf_map_lookup_elem(sk_stg_map)",
    136			  "err:%d errno:%d\n", err, errno))
    137			goto done;
    138	}
    139
    140	err = pthread_create(&srv_thread, NULL, server, (void *)(long)lfd);
    141	if (CHECK(err != 0, "pthread_create", "err:%d errno:%d\n", err, errno))
    142		goto done;
    143
    144	/* recv total_bytes */
    145	while (bytes < total_bytes && !READ_ONCE(stop)) {
    146		nr_recv = recv(fd, &batch,
    147			       MIN(total_bytes - bytes, sizeof(batch)), 0);
    148		if (nr_recv == -1 && errno == EINTR)
    149			continue;
    150		if (nr_recv == -1)
    151			break;
    152		bytes += nr_recv;
    153	}
    154
    155	CHECK(bytes != total_bytes, "recv", "%zd != %u nr_recv:%zd errno:%d\n",
    156	      bytes, total_bytes, nr_recv, errno);
    157
    158	WRITE_ONCE(stop, 1);
    159	pthread_join(srv_thread, &thread_ret);
    160	CHECK(IS_ERR(thread_ret), "pthread_join", "thread_ret:%ld",
    161	      PTR_ERR(thread_ret));
    162done:
    163	close(lfd);
    164	close(fd);
    165}
    166
    167static void test_cubic(void)
    168{
    169	struct bpf_cubic *cubic_skel;
    170	struct bpf_link *link;
    171
    172	cubic_skel = bpf_cubic__open_and_load();
    173	if (CHECK(!cubic_skel, "bpf_cubic__open_and_load", "failed\n"))
    174		return;
    175
    176	link = bpf_map__attach_struct_ops(cubic_skel->maps.cubic);
    177	if (!ASSERT_OK_PTR(link, "bpf_map__attach_struct_ops")) {
    178		bpf_cubic__destroy(cubic_skel);
    179		return;
    180	}
    181
    182	do_test("bpf_cubic", NULL);
    183
    184	bpf_link__destroy(link);
    185	bpf_cubic__destroy(cubic_skel);
    186}
    187
    188static void test_dctcp(void)
    189{
    190	struct bpf_dctcp *dctcp_skel;
    191	struct bpf_link *link;
    192
    193	dctcp_skel = bpf_dctcp__open_and_load();
    194	if (CHECK(!dctcp_skel, "bpf_dctcp__open_and_load", "failed\n"))
    195		return;
    196
    197	link = bpf_map__attach_struct_ops(dctcp_skel->maps.dctcp);
    198	if (!ASSERT_OK_PTR(link, "bpf_map__attach_struct_ops")) {
    199		bpf_dctcp__destroy(dctcp_skel);
    200		return;
    201	}
    202
    203	do_test("bpf_dctcp", dctcp_skel->maps.sk_stg_map);
    204	CHECK(dctcp_skel->bss->stg_result != expected_stg,
    205	      "Unexpected stg_result", "stg_result (%x) != expected_stg (%x)\n",
    206	      dctcp_skel->bss->stg_result, expected_stg);
    207
    208	bpf_link__destroy(link);
    209	bpf_dctcp__destroy(dctcp_skel);
    210}
    211
    /*
 * State shared with libbpf_debug_print(): err_str is the substring to look
 * for in a program load log, and found records whether it was seen.
 * err_str is only ever pointed at string literals, so it is const.
 */
static const char *err_str;
static bool found;
    214
    215static int libbpf_debug_print(enum libbpf_print_level level,
    216			      const char *format, va_list args)
    217{
    218	const char *prog_name, *log_buf;
    219
    220	if (level != LIBBPF_WARN ||
    221	    !strstr(format, "-- BEGIN PROG LOAD LOG --")) {
    222		vprintf(format, args);
    223		return 0;
    224	}
    225
    226	prog_name = va_arg(args, char *);
    227	log_buf = va_arg(args, char *);
    228	if (!log_buf)
    229		goto out;
    230	if (err_str && strstr(log_buf, err_str) != NULL)
    231		found = true;
    232out:
    233	printf(format, prog_name, log_buf);
    234	return 0;
    235}
    236
    237static void test_invalid_license(void)
    238{
    239	libbpf_print_fn_t old_print_fn;
    240	struct bpf_tcp_nogpl *skel;
    241
    242	err_str = "struct ops programs must have a GPL compatible license";
    243	found = false;
    244	old_print_fn = libbpf_set_print(libbpf_debug_print);
    245
    246	skel = bpf_tcp_nogpl__open_and_load();
    247	ASSERT_NULL(skel, "bpf_tcp_nogpl");
    248	ASSERT_EQ(found, true, "expected_err_msg");
    249
    250	bpf_tcp_nogpl__destroy(skel);
    251	libbpf_set_print(old_print_fn);
    252}
    253
    254static void test_dctcp_fallback(void)
    255{
    256	int err, lfd = -1, cli_fd = -1, srv_fd = -1;
    257	struct network_helper_opts opts = {
    258		.cc = "cubic",
    259	};
    260	struct bpf_dctcp *dctcp_skel;
    261	struct bpf_link *link = NULL;
    262	char srv_cc[16];
    263	socklen_t cc_len = sizeof(srv_cc);
    264
    265	dctcp_skel = bpf_dctcp__open();
    266	if (!ASSERT_OK_PTR(dctcp_skel, "dctcp_skel"))
    267		return;
    268	strcpy(dctcp_skel->rodata->fallback, "cubic");
    269	if (!ASSERT_OK(bpf_dctcp__load(dctcp_skel), "bpf_dctcp__load"))
    270		goto done;
    271
    272	link = bpf_map__attach_struct_ops(dctcp_skel->maps.dctcp);
    273	if (!ASSERT_OK_PTR(link, "dctcp link"))
    274		goto done;
    275
    276	lfd = start_server(AF_INET6, SOCK_STREAM, "::1", 0, 0);
    277	if (!ASSERT_GE(lfd, 0, "lfd") ||
    278	    !ASSERT_OK(settcpca(lfd, "bpf_dctcp"), "lfd=>bpf_dctcp"))
    279		goto done;
    280
    281	cli_fd = connect_to_fd_opts(lfd, &opts);
    282	if (!ASSERT_GE(cli_fd, 0, "cli_fd"))
    283		goto done;
    284
    285	srv_fd = accept(lfd, NULL, 0);
    286	if (!ASSERT_GE(srv_fd, 0, "srv_fd"))
    287		goto done;
    288	ASSERT_STREQ(dctcp_skel->bss->cc_res, "cubic", "cc_res");
    289	ASSERT_EQ(dctcp_skel->bss->tcp_cdg_res, -ENOTSUPP, "tcp_cdg_res");
    290
    291	err = getsockopt(srv_fd, SOL_TCP, TCP_CONGESTION, srv_cc, &cc_len);
    292	if (!ASSERT_OK(err, "getsockopt(srv_fd, TCP_CONGESTION)"))
    293		goto done;
    294	ASSERT_STREQ(srv_cc, "cubic", "srv_fd cc");
    295
    296done:
    297	bpf_link__destroy(link);
    298	bpf_dctcp__destroy(dctcp_skel);
    299	if (lfd != -1)
    300		close(lfd);
    301	if (srv_fd != -1)
    302		close(srv_fd);
    303	if (cli_fd != -1)
    304		close(cli_fd);
    305}
    306
    307static void test_rel_setsockopt(void)
    308{
    309	struct bpf_dctcp_release *rel_skel;
    310	libbpf_print_fn_t old_print_fn;
    311
    312	err_str = "unknown func bpf_setsockopt";
    313	found = false;
    314
    315	old_print_fn = libbpf_set_print(libbpf_debug_print);
    316	rel_skel = bpf_dctcp_release__open_and_load();
    317	libbpf_set_print(old_print_fn);
    318
    319	ASSERT_ERR_PTR(rel_skel, "rel_skel");
    320	ASSERT_TRUE(found, "expected_err_msg");
    321
    322	bpf_dctcp_release__destroy(rel_skel);
    323}
    324
    325void test_bpf_tcp_ca(void)
    326{
    327	if (test__start_subtest("dctcp"))
    328		test_dctcp();
    329	if (test__start_subtest("cubic"))
    330		test_cubic();
    331	if (test__start_subtest("invalid_license"))
    332		test_invalid_license();
    333	if (test__start_subtest("dctcp_fallback"))
    334		test_dctcp_fallback();
    335	if (test__start_subtest("rel_setsockopt"))
    336		test_rel_setsockopt();
    337}