cachepc-linux

Fork of AMDESE/linux with modifications for CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux
Log | Files | Refs | README | LICENSE | sfeed.txt

connect4_prog.c (4539B)


      1// SPDX-License-Identifier: GPL-2.0
      2// Copyright (c) 2018 Facebook
      3
      4#include <string.h>
      5
      6#include <linux/stddef.h>
      7#include <linux/bpf.h>
      8#include <linux/in.h>
      9#include <linux/in6.h>
     10#include <sys/socket.h>
     11#include <netinet/tcp.h>
     12#include <linux/if.h>
     13#include <errno.h>
     14
     15#include <bpf/bpf_helpers.h>
     16#include <bpf/bpf_endian.h>
     17
     18#define SRC_REWRITE_IP4		0x7f000004U
     19#define DST_REWRITE_IP4		0x7f000001U
     20#define DST_REWRITE_PORT4	4444
     21
     22#ifndef TCP_CA_NAME_MAX
     23#define TCP_CA_NAME_MAX 16
     24#endif
     25
     26#ifndef TCP_NOTSENT_LOWAT
     27#define TCP_NOTSENT_LOWAT 25
     28#endif
     29
     30#ifndef IFNAMSIZ
     31#define IFNAMSIZ 16
     32#endif
     33
     34__attribute__ ((noinline))
     35int do_bind(struct bpf_sock_addr *ctx)
     36{
     37	struct sockaddr_in sa = {};
     38
     39	sa.sin_family = AF_INET;
     40	sa.sin_port = bpf_htons(0);
     41	sa.sin_addr.s_addr = bpf_htonl(SRC_REWRITE_IP4);
     42
     43	if (bpf_bind(ctx, (struct sockaddr *)&sa, sizeof(sa)) != 0)
     44		return 0;
     45
     46	return 1;
     47}
     48
     49static __inline int verify_cc(struct bpf_sock_addr *ctx,
     50			      char expected[TCP_CA_NAME_MAX])
     51{
     52	char buf[TCP_CA_NAME_MAX];
     53	int i;
     54
     55	if (bpf_getsockopt(ctx, SOL_TCP, TCP_CONGESTION, &buf, sizeof(buf)))
     56		return 1;
     57
     58	for (i = 0; i < TCP_CA_NAME_MAX; i++) {
     59		if (buf[i] != expected[i])
     60			return 1;
     61		if (buf[i] == 0)
     62			break;
     63	}
     64
     65	return 0;
     66}
     67
     68static __inline int set_cc(struct bpf_sock_addr *ctx)
     69{
     70	char reno[TCP_CA_NAME_MAX] = "reno";
     71	char cubic[TCP_CA_NAME_MAX] = "cubic";
     72
     73	if (bpf_setsockopt(ctx, SOL_TCP, TCP_CONGESTION, &reno, sizeof(reno)))
     74		return 1;
     75	if (verify_cc(ctx, reno))
     76		return 1;
     77
     78	if (bpf_setsockopt(ctx, SOL_TCP, TCP_CONGESTION, &cubic, sizeof(cubic)))
     79		return 1;
     80	if (verify_cc(ctx, cubic))
     81		return 1;
     82
     83	return 0;
     84}
     85
     86static __inline int bind_to_device(struct bpf_sock_addr *ctx)
     87{
     88	char veth1[IFNAMSIZ] = "test_sock_addr1";
     89	char veth2[IFNAMSIZ] = "test_sock_addr2";
     90	char missing[IFNAMSIZ] = "nonexistent_dev";
     91	char del_bind[IFNAMSIZ] = "";
     92
     93	if (bpf_setsockopt(ctx, SOL_SOCKET, SO_BINDTODEVICE,
     94				&veth1, sizeof(veth1)))
     95		return 1;
     96	if (bpf_setsockopt(ctx, SOL_SOCKET, SO_BINDTODEVICE,
     97				&veth2, sizeof(veth2)))
     98		return 1;
     99	if (bpf_setsockopt(ctx, SOL_SOCKET, SO_BINDTODEVICE,
    100				&missing, sizeof(missing)) != -ENODEV)
    101		return 1;
    102	if (bpf_setsockopt(ctx, SOL_SOCKET, SO_BINDTODEVICE,
    103				&del_bind, sizeof(del_bind)))
    104		return 1;
    105
    106	return 0;
    107}
    108
    109static __inline int set_keepalive(struct bpf_sock_addr *ctx)
    110{
    111	int zero = 0, one = 1;
    112
    113	if (bpf_setsockopt(ctx, SOL_SOCKET, SO_KEEPALIVE, &one, sizeof(one)))
    114		return 1;
    115	if (ctx->type == SOCK_STREAM) {
    116		if (bpf_setsockopt(ctx, SOL_TCP, TCP_KEEPIDLE, &one, sizeof(one)))
    117			return 1;
    118		if (bpf_setsockopt(ctx, SOL_TCP, TCP_KEEPINTVL, &one, sizeof(one)))
    119			return 1;
    120		if (bpf_setsockopt(ctx, SOL_TCP, TCP_KEEPCNT, &one, sizeof(one)))
    121			return 1;
    122		if (bpf_setsockopt(ctx, SOL_TCP, TCP_SYNCNT, &one, sizeof(one)))
    123			return 1;
    124		if (bpf_setsockopt(ctx, SOL_TCP, TCP_USER_TIMEOUT, &one, sizeof(one)))
    125			return 1;
    126	}
    127	if (bpf_setsockopt(ctx, SOL_SOCKET, SO_KEEPALIVE, &zero, sizeof(zero)))
    128		return 1;
    129
    130	return 0;
    131}
    132
    133static __inline int set_notsent_lowat(struct bpf_sock_addr *ctx)
    134{
    135	int lowat = 65535;
    136
    137	if (ctx->type == SOCK_STREAM) {
    138		if (bpf_setsockopt(ctx, SOL_TCP, TCP_NOTSENT_LOWAT, &lowat, sizeof(lowat)))
    139			return 1;
    140	}
    141
    142	return 0;
    143}
    144
    145SEC("cgroup/connect4")
    146int connect_v4_prog(struct bpf_sock_addr *ctx)
    147{
    148	struct bpf_sock_tuple tuple = {};
    149	struct bpf_sock *sk;
    150
    151	/* Verify that new destination is available. */
    152	memset(&tuple.ipv4.saddr, 0, sizeof(tuple.ipv4.saddr));
    153	memset(&tuple.ipv4.sport, 0, sizeof(tuple.ipv4.sport));
    154
    155	tuple.ipv4.daddr = bpf_htonl(DST_REWRITE_IP4);
    156	tuple.ipv4.dport = bpf_htons(DST_REWRITE_PORT4);
    157
    158	/* Bind to device and unbind it. */
    159	if (bind_to_device(ctx))
    160		return 0;
    161
    162	if (set_keepalive(ctx))
    163		return 0;
    164
    165	if (set_notsent_lowat(ctx))
    166		return 0;
    167
    168	if (ctx->type != SOCK_STREAM && ctx->type != SOCK_DGRAM)
    169		return 0;
    170	else if (ctx->type == SOCK_STREAM)
    171		sk = bpf_sk_lookup_tcp(ctx, &tuple, sizeof(tuple.ipv4),
    172				       BPF_F_CURRENT_NETNS, 0);
    173	else
    174		sk = bpf_sk_lookup_udp(ctx, &tuple, sizeof(tuple.ipv4),
    175				       BPF_F_CURRENT_NETNS, 0);
    176
    177	if (!sk)
    178		return 0;
    179
    180	if (sk->src_ip4 != tuple.ipv4.daddr ||
    181	    sk->src_port != DST_REWRITE_PORT4) {
    182		bpf_sk_release(sk);
    183		return 0;
    184	}
    185
    186	bpf_sk_release(sk);
    187
    188	/* Rewrite congestion control. */
    189	if (ctx->type == SOCK_STREAM && set_cc(ctx))
    190		return 0;
    191
    192	/* Rewrite destination. */
    193	ctx->user_ip4 = bpf_htonl(DST_REWRITE_IP4);
    194	ctx->user_port = bpf_htons(DST_REWRITE_PORT4);
    195
    196	return do_bind(ctx) ? 1 : 0;
    197}
    198
    199char _license[] SEC("license") = "GPL";