cachepc-linux

Fork of AMDESE/linux with modifications for CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux
Log | Files | Refs | README | LICENSE | sfeed.txt

ipsec.c (57370B)


      1// SPDX-License-Identifier: GPL-2.0
      2/*
      3 * ipsec.c - Check xfrm on veth inside a net-ns.
      4 * Copyright (c) 2018 Dmitry Safonov
      5 */
      6
      7#define _GNU_SOURCE
      8
      9#include <arpa/inet.h>
     10#include <asm/types.h>
     11#include <errno.h>
     12#include <fcntl.h>
     13#include <limits.h>
     14#include <linux/limits.h>
     15#include <linux/netlink.h>
     16#include <linux/random.h>
     17#include <linux/rtnetlink.h>
     18#include <linux/veth.h>
     19#include <linux/xfrm.h>
     20#include <netinet/in.h>
     21#include <net/if.h>
     22#include <sched.h>
     23#include <stdbool.h>
     24#include <stdint.h>
     25#include <stdio.h>
     26#include <stdlib.h>
     27#include <string.h>
     28#include <sys/mman.h>
     29#include <sys/socket.h>
     30#include <sys/stat.h>
     31#include <sys/syscall.h>
     32#include <sys/types.h>
     33#include <sys/wait.h>
     34#include <time.h>
     35#include <unistd.h>
     36
     37#include "../kselftest.h"
     38
     /* Log one line through kselftest, tagged with the caller's pid and line. */
     #define printk(fmt, ...)						\
     	ksft_print_msg("%d[%u] " fmt "\n", getpid(), __LINE__, ##__VA_ARGS__)

     /* printk() variant that appends the errno description via "%m". */
     #define pr_err(fmt, ...)	printk(fmt ": %m", ##__VA_ARGS__)

     /* Breaks the build when @condition is true: the array size becomes -1. */
     #define BUILD_BUG_ON(condition) ((void)sizeof(char[1 - 2*!!(condition)]))

     #define IPV4_STR_SZ	16	/* xxx.xxx.xxx.xxx is longest + \0 */
     #define MAX_PAYLOAD	2048
     #define XFRM_ALGO_KEY_BUF_SIZE	512
     #define MAX_PROCESSES	(1 << 14) /* /16 mask divided by /30 subnets */
     #define INADDR_A	((in_addr_t) 0x0a000000) /* 10.0.0.0 */
     #define INADDR_B	((in_addr_t) 0xc0a80000) /* 192.168.0.0 */

     /* /30 mask for one veth connection */
     #define PREFIX_LEN	30
     /*
      * Host offsets 1 and 2 inside the 4-address block number @nr.  The argument
      * is parenthesized so that an expression such as child_ip(i + 1) expands to
      * 4 * (i + 1) + 1 rather than 4 * i + 1 + 1.
      */
     #define child_ip(nr)	(4 * (nr) + 1)
     #define grchild_ip(nr)	(4 * (nr) + 2)

     #define VETH_FMT	"ktst-%d"
     #define VETH_LEN	12
     60
     /* Namespace descriptors filled in by init_namespaces(); -1 until then. */
     static int nsfd_parent = -1, nsfd_childa = -1, nsfd_childb = -1;
     static long page_size;

     /*
      * ksft_cnt is static in kselftest, so it isn't shared with children.
      * A child therefore reports each test result back to the parent, which
      * does the counting; results_fd is the pipe carrying that feedback.
      */
     static int results_fd[2];

     /* Pause between two pings, handed to nanosleep() as tv_nsec. */
     const unsigned int ping_delay_nsec	= 50000000;
     /* Receive timeout, handed to udp_ping_init() where it becomes tv_usec. */
     const unsigned int ping_timeout		= 300;
     /* Pings attempted per do_ping() call, and how many must succeed. */
     const unsigned int ping_count		= 100;
     const unsigned int ping_success		= 80;
     77
     /*
      * Fill @buflen bytes at @buf with output from rand().
      *
      * One rand() value is consumed per sizeof(int) bytes, plus one more for a
      * trailing partial word.  Bytes are copied with memcpy() so that @buf
      * needs no particular alignment and no arithmetic is done on a void pointer.
      * A zero @buflen leaves the buffer untouched and consumes no random values.
      */
     static void randomize_buffer(void *buf, size_t buflen)
     {
     	unsigned char *dst = buf;

     	while (buflen) {
     		int word = rand();
     		size_t chunk = buflen < sizeof(word) ? buflen : sizeof(word);

     		memcpy(dst, &word, chunk);
     		dst += chunk;
     		buflen -= chunk;
     	}
     }
     98
     /*
      * Move the caller into a new network namespace and open a descriptor on it.
      *
      * Returns the descriptor opened on /proc/self/ns/net, or -1 after logging
      * when unshare() or open() fails.  A descriptor value of 0 is reported as a
      * failure too.
      */
     static int unshare_open(void)
     {
     	const char *ns_file = "/proc/self/ns/net";
     	int nsfd;

     	if (unshare(CLONE_NEWNET)) {
     		pr_err("unshare()");
     		return -1;
     	}

     	nsfd = open(ns_file, O_RDONLY);
     	if (nsfd > 0)
     		return nsfd;

     	pr_err("open(%s)", ns_file);
     	return -1;
     }
    117
    /*
     * Enter the network namespace referred to by @fd.
     * Returns 0 on success, -1 (after logging) when setns() fails.
     */
    static int switch_ns(int fd)
    {
    	if (setns(fd, CLONE_NEWNET) == 0)
    		return 0;

    	pr_err("setns()");
    	return -1;
    }
    126
    127/*
    128 * Running the test inside a new parent net namespace to bother less
    129 * about cleanup on error-path.
    130 */
    131static int init_namespaces(void)
    132{
    133	nsfd_parent = unshare_open();
    134	if (nsfd_parent <= 0)
    135		return -1;
    136
    137	nsfd_childa = unshare_open();
    138	if (nsfd_childa <= 0)
    139		return -1;
    140
    141	if (switch_ns(nsfd_parent))
    142		return -1;
    143
    144	nsfd_childb = unshare_open();
    145	if (nsfd_childb <= 0)
    146		return -1;
    147
    148	if (switch_ns(nsfd_parent))
    149		return -1;
    150	return 0;
    151}
    152
    /*
     * Make sure *@sock is an open netlink socket for protocol @proto.
     *
     * If *@sock is already positive the socket is reused and the sequence
     * number stored at @seq_nr is advanced by one.  Otherwise a new
     * SOCK_RAW | SOCK_CLOEXEC socket is created and *@seq_nr is seeded with
     * random data.
     *
     * Returns 0 on success, -1 (after logging) when socket() fails.
     */
    static int netlink_sock(int *sock, uint32_t *seq_nr, int proto)
    {
    	if (*sock > 0) {
    		/* Advance the counter itself, not the local pointer to it. */
    		(*seq_nr)++;
    		return 0;
    	}

    	*sock = socket(AF_NETLINK, SOCK_RAW | SOCK_CLOEXEC, proto);
    	if (*sock <= 0) {
    		pr_err("socket(AF_NETLINK)");
    		return -1;
    	}

    	randomize_buffer(seq_nr, sizeof(*seq_nr));

    	return 0;
    }
    170
    /*
     * Address at which the next attribute would be appended to @nh: the start
     * of the message plus its current length rounded up with RTA_ALIGN().
     */
    static inline struct rtattr *rtattr_hdr(struct nlmsghdr *nh)
    {
    	char *msg_start = (char *)nh;

    	return (struct rtattr *)(msg_start + RTA_ALIGN(nh->nlmsg_len));
    }
    175
    /*
     * Append one attribute of type @rta_type carrying @size bytes of @payload
     * to the message @nh, which lives in a buffer of @req_sz bytes.
     *
     * nlmsg_len is grown to cover the new attribute.  @payload may be NULL when
     * @size is 0 (as rtattr_begin() does); nothing is copied in that case.
     *
     * Returns 0 on success, -1 (after logging) if the buffer is too small; the
     * message is left unmodified on failure.
     */
    static int rtattr_pack(struct nlmsghdr *nh, size_t req_sz,
    		unsigned short rta_type, const void *payload, size_t size)
    {
    	/* NLMSG_ALIGNTO == RTA_ALIGNTO, nlmsg_len already aligned */
    	struct rtattr *attr = rtattr_hdr(nh);
    	size_t nl_size = RTA_ALIGN(nh->nlmsg_len) + RTA_LENGTH(size);

    	if (req_sz < nl_size) {
    		printk("req buf is too small: %zu < %zu", req_sz, nl_size);
    		return -1;
    	}
    	nh->nlmsg_len = nl_size;

    	attr->rta_len = RTA_LENGTH(size);
    	attr->rta_type = rta_type;
    	/* memcpy() must not be handed a NULL source, even for zero bytes. */
    	if (size)
    		memcpy(RTA_DATA(attr), payload, size);

    	return 0;
    }
    195
    /*
     * Append an attribute that will act as a container for further attributes.
     *
     * Returns the address of the appended attribute so the caller can later
     * fix up its length with rtattr_end(), or NULL if packing failed.
     */
    static struct rtattr *_rtattr_begin(struct nlmsghdr *nh, size_t req_sz,
    		unsigned short rta_type, const void *payload, size_t size)
    {
    	/* Must be taken before rtattr_pack() grows nlmsg_len. */
    	struct rtattr *container = rtattr_hdr(nh);

    	return rtattr_pack(nh, req_sz, rta_type, payload, size) ? NULL : container;
    }
    206
    /*
     * Open a container attribute of type @rta_type that has no payload of its
     * own.  Returns the attribute, or NULL if it did not fit into @req_sz.
     */
    static inline struct rtattr *rtattr_begin(struct nlmsghdr *nh, size_t req_sz,
    		unsigned short rta_type)
    {
    	return _rtattr_begin(nh, req_sz, rta_type, NULL, 0);
    }
    212
    /*
     * Close a container attribute: set its rta_len to the distance from @attr
     * to the current end of the message @nh.
     */
    static inline void rtattr_end(struct nlmsghdr *nh, struct rtattr *attr)
    {
    	attr->rta_len = ((char *)nh + nh->nlmsg_len) - (char *)attr;
    }
    219
    /*
     * Append the VETH_INFO_PEER container to @nh.  It carries a zeroed
     * ifinfomsg (family AF_UNSPEC, all change bits set) followed by the peer's
     * interface name @peer and the namespace descriptor @ns.
     *
     * Returns 0 on success, -1 if any attribute did not fit into @req_sz.
     */
    static int veth_pack_peerb(struct nlmsghdr *nh, size_t req_sz,
    		const char *peer, int ns)
    {
    	struct rtattr *container;
    	struct ifinfomsg peer_info;

    	memset(&peer_info, 0, sizeof(peer_info));
    	peer_info.ifi_change	= 0xFFFFFFFF;
    	peer_info.ifi_family	= AF_UNSPEC;

    	container = _rtattr_begin(nh, req_sz, VETH_INFO_PEER,
    			&peer_info, sizeof(peer_info));
    	if (!container)
    		return -1;

    	if (rtattr_pack(nh, req_sz, IFLA_IFNAME, peer, strlen(peer)) ||
    	    rtattr_pack(nh, req_sz, IFLA_NET_NS_FD, &ns, sizeof(ns)))
    		return -1;

    	rtattr_end(nh, container);

    	return 0;
    }
    244
    /*
     * Read one reply from @sock and interpret it as a netlink acknowledgement.
     *
     * Returns 0 when the reply is an NLMSG_ERROR message with error code 0,
     * the (non-zero) error code carried by the message when there is one, and
     * -1 if recv() fails or the reply is of another type.
     */
    static int netlink_check_answer(int sock)
    {
    	struct nlmsgerror {
    		struct nlmsghdr hdr;
    		int error;
    		struct nlmsghdr orig_msg;
    	} answer;

    	if (recv(sock, &answer, sizeof(answer), 0) < 0) {
    		pr_err("recv()");
    		return -1;
    	}

    	if (answer.hdr.nlmsg_type != NLMSG_ERROR) {
    		printk("expected NLMSG_ERROR, got %d", (int)answer.hdr.nlmsg_type);
    		return -1;
    	}

    	if (!answer.error)
    		return 0;

    	printk("NLMSG_ERROR: %d: %s",
    		answer.error, strerror(-answer.error));
    	return answer.error;
    }
    267
    268static int veth_add(int sock, uint32_t seq, const char *peera, int ns_a,
    269		const char *peerb, int ns_b)
    270{
    271	uint16_t flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE;
    272	struct {
    273		struct nlmsghdr		nh;
    274		struct ifinfomsg	info;
    275		char			attrbuf[MAX_PAYLOAD];
    276	} req;
    277	const char veth_type[] = "veth";
    278	struct rtattr *link_info, *info_data;
    279
    280	memset(&req, 0, sizeof(req));
    281	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));
    282	req.nh.nlmsg_type	= RTM_NEWLINK;
    283	req.nh.nlmsg_flags	= flags;
    284	req.nh.nlmsg_seq	= seq;
    285	req.info.ifi_family	= AF_UNSPEC;
    286	req.info.ifi_change	= 0xFFFFFFFF;
    287
    288	if (rtattr_pack(&req.nh, sizeof(req), IFLA_IFNAME, peera, strlen(peera)))
    289		return -1;
    290
    291	if (rtattr_pack(&req.nh, sizeof(req), IFLA_NET_NS_FD, &ns_a, sizeof(ns_a)))
    292		return -1;
    293
    294	link_info = rtattr_begin(&req.nh, sizeof(req), IFLA_LINKINFO);
    295	if (!link_info)
    296		return -1;
    297
    298	if (rtattr_pack(&req.nh, sizeof(req), IFLA_INFO_KIND, veth_type, sizeof(veth_type)))
    299		return -1;
    300
    301	info_data = rtattr_begin(&req.nh, sizeof(req), IFLA_INFO_DATA);
    302	if (!info_data)
    303		return -1;
    304
    305	if (veth_pack_peerb(&req.nh, sizeof(req), peerb, ns_b))
    306		return -1;
    307
    308	rtattr_end(&req.nh, info_data);
    309	rtattr_end(&req.nh, link_info);
    310
    311	if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
    312		pr_err("send()");
    313		return -1;
    314	}
    315	return netlink_check_answer(sock);
    316}
    317
    318static int ip4_addr_set(int sock, uint32_t seq, const char *intf,
    319		struct in_addr addr, uint8_t prefix)
    320{
    321	uint16_t flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE;
    322	struct {
    323		struct nlmsghdr		nh;
    324		struct ifaddrmsg	info;
    325		char			attrbuf[MAX_PAYLOAD];
    326	} req;
    327
    328	memset(&req, 0, sizeof(req));
    329	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));
    330	req.nh.nlmsg_type	= RTM_NEWADDR;
    331	req.nh.nlmsg_flags	= flags;
    332	req.nh.nlmsg_seq	= seq;
    333	req.info.ifa_family	= AF_INET;
    334	req.info.ifa_prefixlen	= prefix;
    335	req.info.ifa_index	= if_nametoindex(intf);
    336
    337#ifdef DEBUG
    338	{
    339		char addr_str[IPV4_STR_SZ] = {};
    340
    341		strncpy(addr_str, inet_ntoa(addr), IPV4_STR_SZ - 1);
    342
    343		printk("ip addr set %s", addr_str);
    344	}
    345#endif
    346
    347	if (rtattr_pack(&req.nh, sizeof(req), IFA_LOCAL, &addr, sizeof(addr)))
    348		return -1;
    349
    350	if (rtattr_pack(&req.nh, sizeof(req), IFA_ADDRESS, &addr, sizeof(addr)))
    351		return -1;
    352
    353	if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
    354		pr_err("send()");
    355		return -1;
    356	}
    357	return netlink_check_answer(sock);
    358}
    359
    360static int link_set_up(int sock, uint32_t seq, const char *intf)
    361{
    362	struct {
    363		struct nlmsghdr		nh;
    364		struct ifinfomsg	info;
    365		char			attrbuf[MAX_PAYLOAD];
    366	} req;
    367
    368	memset(&req, 0, sizeof(req));
    369	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));
    370	req.nh.nlmsg_type	= RTM_NEWLINK;
    371	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
    372	req.nh.nlmsg_seq	= seq;
    373	req.info.ifi_family	= AF_UNSPEC;
    374	req.info.ifi_change	= 0xFFFFFFFF;
    375	req.info.ifi_index	= if_nametoindex(intf);
    376	req.info.ifi_flags	= IFF_UP;
    377	req.info.ifi_change	= IFF_UP;
    378
    379	if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
    380		pr_err("send()");
    381		return -1;
    382	}
    383	return netlink_check_answer(sock);
    384}
    385
    386static int ip4_route_set(int sock, uint32_t seq, const char *intf,
    387		struct in_addr src, struct in_addr dst)
    388{
    389	struct {
    390		struct nlmsghdr	nh;
    391		struct rtmsg	rt;
    392		char		attrbuf[MAX_PAYLOAD];
    393	} req;
    394	unsigned int index = if_nametoindex(intf);
    395
    396	memset(&req, 0, sizeof(req));
    397	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.rt));
    398	req.nh.nlmsg_type	= RTM_NEWROUTE;
    399	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE;
    400	req.nh.nlmsg_seq	= seq;
    401	req.rt.rtm_family	= AF_INET;
    402	req.rt.rtm_dst_len	= 32;
    403	req.rt.rtm_table	= RT_TABLE_MAIN;
    404	req.rt.rtm_protocol	= RTPROT_BOOT;
    405	req.rt.rtm_scope	= RT_SCOPE_LINK;
    406	req.rt.rtm_type		= RTN_UNICAST;
    407
    408	if (rtattr_pack(&req.nh, sizeof(req), RTA_DST, &dst, sizeof(dst)))
    409		return -1;
    410
    411	if (rtattr_pack(&req.nh, sizeof(req), RTA_PREFSRC, &src, sizeof(src)))
    412		return -1;
    413
    414	if (rtattr_pack(&req.nh, sizeof(req), RTA_OIF, &index, sizeof(index)))
    415		return -1;
    416
    417	if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
    418		pr_err("send()");
    419		return -1;
    420	}
    421
    422	return netlink_check_answer(sock);
    423}
    424
    425static int tunnel_set_route(int route_sock, uint32_t *route_seq, char *veth,
    426		struct in_addr tunsrc, struct in_addr tundst)
    427{
    428	if (ip4_addr_set(route_sock, (*route_seq)++, "lo",
    429			tunsrc, PREFIX_LEN)) {
    430		printk("Failed to set ipv4 addr");
    431		return -1;
    432	}
    433
    434	if (ip4_route_set(route_sock, (*route_seq)++, veth, tunsrc, tundst)) {
    435		printk("Failed to set ipv4 route");
    436		return -1;
    437	}
    438
    439	return 0;
    440}
    441
    442static int init_child(int nsfd, char *veth, unsigned int src, unsigned int dst)
    443{
    444	struct in_addr intsrc = inet_makeaddr(INADDR_B, src);
    445	struct in_addr tunsrc = inet_makeaddr(INADDR_A, src);
    446	struct in_addr tundst = inet_makeaddr(INADDR_A, dst);
    447	int route_sock = -1, ret = -1;
    448	uint32_t route_seq;
    449
    450	if (switch_ns(nsfd))
    451		return -1;
    452
    453	if (netlink_sock(&route_sock, &route_seq, NETLINK_ROUTE)) {
    454		printk("Failed to open netlink route socket in child");
    455		return -1;
    456	}
    457
    458	if (ip4_addr_set(route_sock, route_seq++, veth, intsrc, PREFIX_LEN)) {
    459		printk("Failed to set ipv4 addr");
    460		goto err;
    461	}
    462
    463	if (link_set_up(route_sock, route_seq++, veth)) {
    464		printk("Failed to bring up %s", veth);
    465		goto err;
    466	}
    467
    468	if (tunnel_set_route(route_sock, &route_seq, veth, tunsrc, tundst)) {
    469		printk("Failed to add tunnel route on %s", veth);
    470		goto err;
    471	}
    472	ret = 0;
    473
    474err:
    475	close(route_sock);
    476	return ret;
    477}
    478
    /* Size of each algorithm-name field in struct xfrm_desc. */
    #define ALGO_LEN	64

    /* Kinds of test a descriptor can request; indexes desc_name[]. */
    enum desc_type {
    	CREATE_TUNNEL	= 0,
    	ALLOCATE_SPI,
    	MONITOR_ACQUIRE,
    	EXPIRE_STATE,
    	EXPIRE_POLICY,
    	SPDINFO_ATTRS,
    };

    /* Printable name per desc_type, followed by one empty string. */
    const char *desc_name[] = {
    	[CREATE_TUNNEL]		= "create tunnel",
    	[ALLOCATE_SPI]		= "alloc spi",
    	[MONITOR_ACQUIRE]	= "monitor acquire",
    	[EXPIRE_STATE]		= "expire state",
    	[EXPIRE_POLICY]		= "expire policy",
    	[SPDINFO_ATTRS]		= "spdinfo attributes",
    	[SPDINFO_ATTRS + 1]	= "",
    };

    /* One test case: its kind, the IP protocol and the algorithm names. */
    struct xfrm_desc {
    	enum desc_type	type;
    	uint8_t		proto;
    	char		a_algo[ALGO_LEN];
    	char		e_algo[ALGO_LEN];
    	char		c_algo[ALGO_LEN];
    	char		ae_algo[ALGO_LEN];
    	unsigned int	icv_len;
    	/* unsigned key_len; */
    };

    /* Message kinds carried in struct test_desc. */
    enum msg_type {
    	MSG_ACK		= 0,
    	MSG_EXIT,
    	MSG_PING,
    	MSG_XFRM_PREPARE,
    	MSG_XFRM_ADD,
    	MSG_XFRM_DEL,
    	MSG_XFRM_CLEANUP,
    };

    /* Message passed through write_msg()/read_msg(). */
    struct test_desc {
    	enum msg_type type;
    	union {
    		/* MSG_PING: address and port the sender listens on. */
    		struct {
    			in_addr_t reply_ip;
    			unsigned int port;
    		} ping;
    		struct xfrm_desc xfrm_desc;
    	} body;
    };

    /* Record written to the results pipe by write_test_result(). */
    struct test_result {
    	struct xfrm_desc desc;
    	unsigned int res;
    };
    533
    534static void write_test_result(unsigned int res, struct xfrm_desc *d)
    535{
    536	struct test_result tr = {};
    537	ssize_t ret;
    538
    539	tr.desc = *d;
    540	tr.res = res;
    541
    542	ret = write(results_fd[1], &tr, sizeof(tr));
    543	if (ret != sizeof(tr))
    544		pr_err("Failed to write the result in pipe %zd", ret);
    545}
    546
    547static void write_msg(int fd, struct test_desc *msg, bool exit_of_fail)
    548{
    549	ssize_t bytes = write(fd, msg, sizeof(*msg));
    550
    551	/* Make sure that write/read is atomic to a pipe */
    552	BUILD_BUG_ON(sizeof(struct test_desc) > PIPE_BUF);
    553
    554	if (bytes < 0) {
    555		pr_err("write()");
    556		if (exit_of_fail)
    557			exit(KSFT_FAIL);
    558	}
    559	if (bytes != sizeof(*msg)) {
    560		pr_err("sent part of the message %zd/%zu", bytes, sizeof(*msg));
    561		if (exit_of_fail)
    562			exit(KSFT_FAIL);
    563	}
    564}
    565
    566static void read_msg(int fd, struct test_desc *msg, bool exit_of_fail)
    567{
    568	ssize_t bytes = read(fd, msg, sizeof(*msg));
    569
    570	if (bytes < 0) {
    571		pr_err("read()");
    572		if (exit_of_fail)
    573			exit(KSFT_FAIL);
    574	}
    575	if (bytes != sizeof(*msg)) {
    576		pr_err("got incomplete message %zd/%zu", bytes, sizeof(*msg));
    577		if (exit_of_fail)
    578			exit(KSFT_FAIL);
    579	}
    580}
    581
    /*
     * Create the pair of UDP sockets used for pinging.
     *
     * sock[0] is bound to @listen_ip on a port chosen by the kernel and gets a
     * receive timeout of @u_timeout microseconds; the chosen port is stored in
     * *@server_port in host byte order.  sock[1] is an unbound socket.
     *
     * Returns 0 on success.  On failure returns -1 after logging, with no
     * socket left open.
     */
    static int udp_ping_init(struct in_addr listen_ip, unsigned int u_timeout,
    		unsigned int *server_port, int sock[2])
    {
    	struct sockaddr_in server;
    	struct timeval t = { .tv_sec = 0, .tv_usec = u_timeout };
    	socklen_t s_len = sizeof(server);

    	sock[0] = socket(AF_INET, SOCK_DGRAM, 0);
    	if (sock[0] < 0) {
    		pr_err("socket()");
    		return -1;
    	}

    	/* Do not hand uninitialized sin_zero bytes to bind(). */
    	memset(&server, 0, sizeof(server));
    	server.sin_family	= AF_INET;
    	server.sin_port		= 0;
    	memcpy(&server.sin_addr.s_addr, &listen_ip, sizeof(struct in_addr));

    	if (bind(sock[0], (struct sockaddr *)&server, s_len)) {
    		pr_err("bind()");
    		goto err_close_server;
    	}

    	/* Learn which port the kernel picked. */
    	if (getsockname(sock[0], (struct sockaddr *)&server, &s_len)) {
    		pr_err("getsockname()");
    		goto err_close_server;
    	}

    	*server_port = ntohs(server.sin_port);

    	if (setsockopt(sock[0], SOL_SOCKET, SO_RCVTIMEO, (const char *)&t, sizeof t)) {
    		pr_err("setsockopt()");
    		goto err_close_server;
    	}

    	sock[1] = socket(AF_INET, SOCK_DGRAM, 0);
    	if (sock[1] < 0) {
    		pr_err("socket()");
    		goto err_close_server;
    	}

    	return 0;

    err_close_server:
    	close(sock[0]);
    	return -1;
    }
    628
    /*
     * Initiating side of one ping: send @buf_len bytes of @buf from sock[1] to
     * @dest_ip:@port, then wait on sock[0] for the reply and check that it is
     * byte-identical to what was sent.
     *
     * @dest_ip is used as is for sin_addr.s_addr; @port is converted with
     * htons().  Returns 0 on a matching reply, -1 otherwise.  A receive that
     * fails with EAGAIN is not logged.
     */
    static int udp_ping_send(int sock[2], in_addr_t dest_ip, unsigned int port,
    		char *buf, size_t buf_len)
    {
    	struct sockaddr_in server;
    	const struct sockaddr *dest_addr = (struct sockaddr *)&server;
    	/* Byte buffer for the reply (not an array of pointers). */
    	char sock_buf[buf_len];
    	ssize_t r_bytes, s_bytes;

    	memset(&server, 0, sizeof(server));
    	server.sin_family	= AF_INET;
    	server.sin_port		= htons(port);
    	server.sin_addr.s_addr	= dest_ip;

    	s_bytes = sendto(sock[1], buf, buf_len, 0, dest_addr, sizeof(server));
    	if (s_bytes < 0) {
    		pr_err("sendto()");
    		return -1;
    	} else if ((size_t)s_bytes != buf_len) {
    		printk("send part of the message: %zd/%zu", s_bytes, buf_len);
    		return -1;
    	}

    	r_bytes = recv(sock[0], sock_buf, buf_len, 0);
    	if (r_bytes < 0) {
    		if (errno != EAGAIN)
    			pr_err("recv()");
    		return -1;
    	} else if (r_bytes == 0) { /* EOF */
    		printk("EOF on reply to ping");
    		return -1;
    	} else if ((size_t)r_bytes != buf_len || memcmp(buf, sock_buf, buf_len)) {
    		printk("ping reply packet is corrupted %zd/%zu", r_bytes, buf_len);
    		return -1;
    	}

    	return 0;
    }
    665
    /*
     * Replying side of one ping: wait on sock[0] for a packet, check that it
     * holds exactly the @buf_len bytes of @buf, then send @buf from sock[1] to
     * @dest_ip:@port.
     *
     * @dest_ip is used as is for sin_addr.s_addr; @port is converted with
     * htons().  Returns 0 when the packet matched and the reply was sent in
     * full, -1 otherwise.  A receive that fails with EAGAIN is not logged.
     */
    static int udp_ping_reply(int sock[2], in_addr_t dest_ip, unsigned int port,
    		char *buf, size_t buf_len)
    {
    	struct sockaddr_in server;
    	const struct sockaddr *dest_addr = (struct sockaddr *)&server;
    	/* Byte buffer for the incoming packet (not an array of pointers). */
    	char sock_buf[buf_len];
    	ssize_t r_bytes, s_bytes;

    	memset(&server, 0, sizeof(server));
    	server.sin_family	= AF_INET;
    	server.sin_port		= htons(port);
    	server.sin_addr.s_addr	= dest_ip;

    	r_bytes = recv(sock[0], sock_buf, buf_len, 0);
    	if (r_bytes < 0) {
    		if (errno != EAGAIN)
    			pr_err("recv()");
    		return -1;
    	}
    	if (r_bytes == 0) { /* EOF */
    		printk("EOF on reply to ping");
    		return -1;
    	}
    	if ((size_t)r_bytes != buf_len || memcmp(buf, sock_buf, buf_len)) {
    		printk("ping reply packet is corrupted %zd/%zu", r_bytes, buf_len);
    		return -1;
    	}

    	s_bytes = sendto(sock[1], buf, buf_len, 0, dest_addr, sizeof(server));
    	if (s_bytes < 0) {
    		pr_err("sendto()");
    		return -1;
    	} else if ((size_t)s_bytes != buf_len) {
    		printk("send part of the message: %zd/%zu", s_bytes, buf_len);
    		return -1;
    	}

    	return 0;
    }
    704
    705typedef int (*ping_f)(int sock[2], in_addr_t dest_ip, unsigned int port,
    706		char *buf, size_t buf_len);
    707static int do_ping(int cmd_fd, char *buf, size_t buf_len, struct in_addr from,
    708		bool init_side, int d_port, in_addr_t to, ping_f func)
    709{
    710	struct test_desc msg;
    711	unsigned int s_port, i, ping_succeeded = 0;
    712	int ping_sock[2];
    713	char to_str[IPV4_STR_SZ] = {}, from_str[IPV4_STR_SZ] = {};
    714
    715	if (udp_ping_init(from, ping_timeout, &s_port, ping_sock)) {
    716		printk("Failed to init ping");
    717		return -1;
    718	}
    719
    720	memset(&msg, 0, sizeof(msg));
    721	msg.type		= MSG_PING;
    722	msg.body.ping.port	= s_port;
    723	memcpy(&msg.body.ping.reply_ip, &from, sizeof(from));
    724
    725	write_msg(cmd_fd, &msg, 0);
    726	if (init_side) {
    727		/* The other end sends ip to ping */
    728		read_msg(cmd_fd, &msg, 0);
    729		if (msg.type != MSG_PING)
    730			return -1;
    731		to = msg.body.ping.reply_ip;
    732		d_port = msg.body.ping.port;
    733	}
    734
    735	for (i = 0; i < ping_count ; i++) {
    736		struct timespec sleep_time = {
    737			.tv_sec = 0,
    738			.tv_nsec = ping_delay_nsec,
    739		};
    740
    741		ping_succeeded += !func(ping_sock, to, d_port, buf, page_size);
    742		nanosleep(&sleep_time, 0);
    743	}
    744
    745	close(ping_sock[0]);
    746	close(ping_sock[1]);
    747
    748	strncpy(to_str, inet_ntoa(*(struct in_addr *)&to), IPV4_STR_SZ - 1);
    749	strncpy(from_str, inet_ntoa(from), IPV4_STR_SZ - 1);
    750
    751	if (ping_succeeded < ping_success) {
    752		printk("ping (%s) %s->%s failed %u/%u times",
    753			init_side ? "send" : "reply", from_str, to_str,
    754			ping_count - ping_succeeded, ping_count);
    755		return -1;
    756	}
    757
    758#ifdef DEBUG
    759	printk("ping (%s) %s->%s succeeded %u/%u times",
    760		init_side ? "send" : "reply", from_str, to_str,
    761		ping_succeeded, ping_count);
    762#endif
    763
    764	return 0;
    765}
    766
    767static int xfrm_fill_key(char *name, char *buf,
    768		size_t buf_len, unsigned int *key_len)
    769{
    770	/* TODO: use set/map instead */
    771	if (strncmp(name, "digest_null", ALGO_LEN) == 0)
    772		*key_len = 0;
    773	else if (strncmp(name, "ecb(cipher_null)", ALGO_LEN) == 0)
    774		*key_len = 0;
    775	else if (strncmp(name, "cbc(des)", ALGO_LEN) == 0)
    776		*key_len = 64;
    777	else if (strncmp(name, "hmac(md5)", ALGO_LEN) == 0)
    778		*key_len = 128;
    779	else if (strncmp(name, "cmac(aes)", ALGO_LEN) == 0)
    780		*key_len = 128;
    781	else if (strncmp(name, "xcbc(aes)", ALGO_LEN) == 0)
    782		*key_len = 128;
    783	else if (strncmp(name, "cbc(cast5)", ALGO_LEN) == 0)
    784		*key_len = 128;
    785	else if (strncmp(name, "cbc(serpent)", ALGO_LEN) == 0)
    786		*key_len = 128;
    787	else if (strncmp(name, "hmac(sha1)", ALGO_LEN) == 0)
    788		*key_len = 160;
    789	else if (strncmp(name, "hmac(rmd160)", ALGO_LEN) == 0)
    790		*key_len = 160;
    791	else if (strncmp(name, "cbc(des3_ede)", ALGO_LEN) == 0)
    792		*key_len = 192;
    793	else if (strncmp(name, "hmac(sha256)", ALGO_LEN) == 0)
    794		*key_len = 256;
    795	else if (strncmp(name, "cbc(aes)", ALGO_LEN) == 0)
    796		*key_len = 256;
    797	else if (strncmp(name, "cbc(camellia)", ALGO_LEN) == 0)
    798		*key_len = 256;
    799	else if (strncmp(name, "cbc(twofish)", ALGO_LEN) == 0)
    800		*key_len = 256;
    801	else if (strncmp(name, "rfc3686(ctr(aes))", ALGO_LEN) == 0)
    802		*key_len = 288;
    803	else if (strncmp(name, "hmac(sha384)", ALGO_LEN) == 0)
    804		*key_len = 384;
    805	else if (strncmp(name, "cbc(blowfish)", ALGO_LEN) == 0)
    806		*key_len = 448;
    807	else if (strncmp(name, "hmac(sha512)", ALGO_LEN) == 0)
    808		*key_len = 512;
    809	else if (strncmp(name, "rfc4106(gcm(aes))-128", ALGO_LEN) == 0)
    810		*key_len = 160;
    811	else if (strncmp(name, "rfc4543(gcm(aes))-128", ALGO_LEN) == 0)
    812		*key_len = 160;
    813	else if (strncmp(name, "rfc4309(ccm(aes))-128", ALGO_LEN) == 0)
    814		*key_len = 152;
    815	else if (strncmp(name, "rfc4106(gcm(aes))-192", ALGO_LEN) == 0)
    816		*key_len = 224;
    817	else if (strncmp(name, "rfc4543(gcm(aes))-192", ALGO_LEN) == 0)
    818		*key_len = 224;
    819	else if (strncmp(name, "rfc4309(ccm(aes))-192", ALGO_LEN) == 0)
    820		*key_len = 216;
    821	else if (strncmp(name, "rfc4106(gcm(aes))-256", ALGO_LEN) == 0)
    822		*key_len = 288;
    823	else if (strncmp(name, "rfc4543(gcm(aes))-256", ALGO_LEN) == 0)
    824		*key_len = 288;
    825	else if (strncmp(name, "rfc4309(ccm(aes))-256", ALGO_LEN) == 0)
    826		*key_len = 280;
    827	else if (strncmp(name, "rfc7539(chacha20,poly1305)-128", ALGO_LEN) == 0)
    828		*key_len = 0;
    829
    830	if (*key_len > buf_len) {
    831		printk("Can't pack a key - too big for buffer");
    832		return -1;
    833	}
    834
    835	randomize_buffer(buf, *key_len);
    836
    837	return 0;
    838}
    839
    840static int xfrm_state_pack_algo(struct nlmsghdr *nh, size_t req_sz,
    841		struct xfrm_desc *desc)
    842{
    843	struct {
    844		union {
    845			struct xfrm_algo	alg;
    846			struct xfrm_algo_aead	aead;
    847			struct xfrm_algo_auth	auth;
    848		} u;
    849		char buf[XFRM_ALGO_KEY_BUF_SIZE];
    850	} alg = {};
    851	size_t alen, elen, clen, aelen;
    852	unsigned short type;
    853
    854	alen = strlen(desc->a_algo);
    855	elen = strlen(desc->e_algo);
    856	clen = strlen(desc->c_algo);
    857	aelen = strlen(desc->ae_algo);
    858
    859	/* Verify desc */
    860	switch (desc->proto) {
    861	case IPPROTO_AH:
    862		if (!alen || elen || clen || aelen) {
    863			printk("BUG: buggy ah desc");
    864			return -1;
    865		}
    866		strncpy(alg.u.alg.alg_name, desc->a_algo, ALGO_LEN - 1);
    867		if (xfrm_fill_key(desc->a_algo, alg.u.alg.alg_key,
    868				sizeof(alg.buf), &alg.u.alg.alg_key_len))
    869			return -1;
    870		type = XFRMA_ALG_AUTH;
    871		break;
    872	case IPPROTO_COMP:
    873		if (!clen || elen || alen || aelen) {
    874			printk("BUG: buggy comp desc");
    875			return -1;
    876		}
    877		strncpy(alg.u.alg.alg_name, desc->c_algo, ALGO_LEN - 1);
    878		if (xfrm_fill_key(desc->c_algo, alg.u.alg.alg_key,
    879				sizeof(alg.buf), &alg.u.alg.alg_key_len))
    880			return -1;
    881		type = XFRMA_ALG_COMP;
    882		break;
    883	case IPPROTO_ESP:
    884		if (!((alen && elen) ^ aelen) || clen) {
    885			printk("BUG: buggy esp desc");
    886			return -1;
    887		}
    888		if (aelen) {
    889			alg.u.aead.alg_icv_len = desc->icv_len;
    890			strncpy(alg.u.aead.alg_name, desc->ae_algo, ALGO_LEN - 1);
    891			if (xfrm_fill_key(desc->ae_algo, alg.u.aead.alg_key,
    892						sizeof(alg.buf), &alg.u.aead.alg_key_len))
    893				return -1;
    894			type = XFRMA_ALG_AEAD;
    895		} else {
    896
    897			strncpy(alg.u.alg.alg_name, desc->e_algo, ALGO_LEN - 1);
    898			type = XFRMA_ALG_CRYPT;
    899			if (xfrm_fill_key(desc->e_algo, alg.u.alg.alg_key,
    900						sizeof(alg.buf), &alg.u.alg.alg_key_len))
    901				return -1;
    902			if (rtattr_pack(nh, req_sz, type, &alg, sizeof(alg)))
    903				return -1;
    904
    905			strncpy(alg.u.alg.alg_name, desc->a_algo, ALGO_LEN);
    906			type = XFRMA_ALG_AUTH;
    907			if (xfrm_fill_key(desc->a_algo, alg.u.alg.alg_key,
    908						sizeof(alg.buf), &alg.u.alg.alg_key_len))
    909				return -1;
    910		}
    911		break;
    912	default:
    913		printk("BUG: unknown proto in desc");
    914		return -1;
    915	}
    916
    917	if (rtattr_pack(nh, req_sz, type, &alg, sizeof(alg)))
    918		return -1;
    919
    920	return 0;
    921}
    922
    /*
     * Derive an SPI from @src: the local network address part reported by
     * inet_lnaof(), converted with htonl().
     */
    static inline uint32_t gen_spi(struct in_addr src)
    {
    	in_addr_t host_part = inet_lnaof(src);

    	return htonl(host_part);
    }
    927
    928static int xfrm_state_add(int xfrm_sock, uint32_t seq, uint32_t spi,
    929		struct in_addr src, struct in_addr dst,
    930		struct xfrm_desc *desc)
    931{
    932	struct {
    933		struct nlmsghdr		nh;
    934		struct xfrm_usersa_info	info;
    935		char			attrbuf[MAX_PAYLOAD];
    936	} req;
    937
    938	memset(&req, 0, sizeof(req));
    939	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));
    940	req.nh.nlmsg_type	= XFRM_MSG_NEWSA;
    941	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
    942	req.nh.nlmsg_seq	= seq;
    943
    944	/* Fill selector. */
    945	memcpy(&req.info.sel.daddr, &dst, sizeof(dst));
    946	memcpy(&req.info.sel.saddr, &src, sizeof(src));
    947	req.info.sel.family		= AF_INET;
    948	req.info.sel.prefixlen_d	= PREFIX_LEN;
    949	req.info.sel.prefixlen_s	= PREFIX_LEN;
    950
    951	/* Fill id */
    952	memcpy(&req.info.id.daddr, &dst, sizeof(dst));
    953	/* Note: zero-spi cannot be deleted */
    954	req.info.id.spi = spi;
    955	req.info.id.proto	= desc->proto;
    956
    957	memcpy(&req.info.saddr, &src, sizeof(src));
    958
    959	/* Fill lifteme_cfg */
    960	req.info.lft.soft_byte_limit	= XFRM_INF;
    961	req.info.lft.hard_byte_limit	= XFRM_INF;
    962	req.info.lft.soft_packet_limit	= XFRM_INF;
    963	req.info.lft.hard_packet_limit	= XFRM_INF;
    964
    965	req.info.family		= AF_INET;
    966	req.info.mode		= XFRM_MODE_TUNNEL;
    967
    968	if (xfrm_state_pack_algo(&req.nh, sizeof(req), desc))
    969		return -1;
    970
    971	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
    972		pr_err("send()");
    973		return -1;
    974	}
    975
    976	return netlink_check_answer(xfrm_sock);
    977}
    978
    979static bool xfrm_usersa_found(struct xfrm_usersa_info *info, uint32_t spi,
    980		struct in_addr src, struct in_addr dst,
    981		struct xfrm_desc *desc)
    982{
    983	if (memcmp(&info->sel.daddr, &dst, sizeof(dst)))
    984		return false;
    985
    986	if (memcmp(&info->sel.saddr, &src, sizeof(src)))
    987		return false;
    988
    989	if (info->sel.family != AF_INET					||
    990			info->sel.prefixlen_d != PREFIX_LEN		||
    991			info->sel.prefixlen_s != PREFIX_LEN)
    992		return false;
    993
    994	if (info->id.spi != spi || info->id.proto != desc->proto)
    995		return false;
    996
    997	if (memcmp(&info->id.daddr, &dst, sizeof(dst)))
    998		return false;
    999
   1000	if (memcmp(&info->saddr, &src, sizeof(src)))
   1001		return false;
   1002
   1003	if (info->lft.soft_byte_limit != XFRM_INF			||
   1004			info->lft.hard_byte_limit != XFRM_INF		||
   1005			info->lft.soft_packet_limit != XFRM_INF		||
   1006			info->lft.hard_packet_limit != XFRM_INF)
   1007		return false;
   1008
   1009	if (info->family != AF_INET || info->mode != XFRM_MODE_TUNNEL)
   1010		return false;
   1011
   1012	/* XXX: check xfrm algo, see xfrm_state_pack_algo(). */
   1013
   1014	return true;
   1015}
   1016
   1017static int xfrm_state_check(int xfrm_sock, uint32_t seq, uint32_t spi,
   1018		struct in_addr src, struct in_addr dst,
   1019		struct xfrm_desc *desc)
   1020{
   1021	struct {
   1022		struct nlmsghdr		nh;
   1023		char			attrbuf[MAX_PAYLOAD];
   1024	} req;
   1025	struct {
   1026		struct nlmsghdr		nh;
   1027		union {
   1028			struct xfrm_usersa_info	info;
   1029			int error;
   1030		};
   1031		char			attrbuf[MAX_PAYLOAD];
   1032	} answer;
   1033	struct xfrm_address_filter filter = {};
   1034	bool found = false;
   1035
   1036
   1037	memset(&req, 0, sizeof(req));
   1038	req.nh.nlmsg_len	= NLMSG_LENGTH(0);
   1039	req.nh.nlmsg_type	= XFRM_MSG_GETSA;
   1040	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_DUMP;
   1041	req.nh.nlmsg_seq	= seq;
   1042
   1043	/*
   1044	 * Add dump filter by source address as there may be other tunnels
   1045	 * in this netns (if tests run in parallel).
   1046	 */
   1047	filter.family = AF_INET;
   1048	filter.splen = 0x1f;	/* 0xffffffff mask see addr_match() */
   1049	memcpy(&filter.saddr, &src, sizeof(src));
   1050	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_ADDRESS_FILTER,
   1051				&filter, sizeof(filter)))
   1052		return -1;
   1053
   1054	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
   1055		pr_err("send()");
   1056		return -1;
   1057	}
   1058
   1059	while (1) {
   1060		if (recv(xfrm_sock, &answer, sizeof(answer), 0) < 0) {
   1061			pr_err("recv()");
   1062			return -1;
   1063		}
   1064		if (answer.nh.nlmsg_type == NLMSG_ERROR) {
   1065			printk("NLMSG_ERROR: %d: %s",
   1066				answer.error, strerror(-answer.error));
   1067			return -1;
   1068		} else if (answer.nh.nlmsg_type == NLMSG_DONE) {
   1069			if (found)
   1070				return 0;
   1071			printk("didn't find allocated xfrm state in dump");
   1072			return -1;
   1073		} else if (answer.nh.nlmsg_type == XFRM_MSG_NEWSA) {
   1074			if (xfrm_usersa_found(&answer.info, spi, src, dst, desc))
   1075				found = true;
   1076		}
   1077	}
   1078}
   1079
   1080static int xfrm_set(int xfrm_sock, uint32_t *seq,
   1081		struct in_addr src, struct in_addr dst,
   1082		struct in_addr tunsrc, struct in_addr tundst,
   1083		struct xfrm_desc *desc)
   1084{
   1085	int err;
   1086
   1087	err = xfrm_state_add(xfrm_sock, (*seq)++, gen_spi(src), src, dst, desc);
   1088	if (err) {
   1089		printk("Failed to add xfrm state");
   1090		return -1;
   1091	}
   1092
   1093	err = xfrm_state_add(xfrm_sock, (*seq)++, gen_spi(src), dst, src, desc);
   1094	if (err) {
   1095		printk("Failed to add xfrm state");
   1096		return -1;
   1097	}
   1098
   1099	/* Check dumps for XFRM_MSG_GETSA */
   1100	err = xfrm_state_check(xfrm_sock, (*seq)++, gen_spi(src), src, dst, desc);
   1101	err |= xfrm_state_check(xfrm_sock, (*seq)++, gen_spi(src), dst, src, desc);
   1102	if (err) {
   1103		printk("Failed to check xfrm state");
   1104		return -1;
   1105	}
   1106
   1107	return 0;
   1108}
   1109
   1110static int xfrm_policy_add(int xfrm_sock, uint32_t seq, uint32_t spi,
   1111		struct in_addr src, struct in_addr dst, uint8_t dir,
   1112		struct in_addr tunsrc, struct in_addr tundst, uint8_t proto)
   1113{
   1114	struct {
   1115		struct nlmsghdr			nh;
   1116		struct xfrm_userpolicy_info	info;
   1117		char				attrbuf[MAX_PAYLOAD];
   1118	} req;
   1119	struct xfrm_user_tmpl tmpl;
   1120
   1121	memset(&req, 0, sizeof(req));
   1122	memset(&tmpl, 0, sizeof(tmpl));
   1123	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));
   1124	req.nh.nlmsg_type	= XFRM_MSG_NEWPOLICY;
   1125	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
   1126	req.nh.nlmsg_seq	= seq;
   1127
   1128	/* Fill selector. */
   1129	memcpy(&req.info.sel.daddr, &dst, sizeof(tundst));
   1130	memcpy(&req.info.sel.saddr, &src, sizeof(tunsrc));
   1131	req.info.sel.family		= AF_INET;
   1132	req.info.sel.prefixlen_d	= PREFIX_LEN;
   1133	req.info.sel.prefixlen_s	= PREFIX_LEN;
   1134
   1135	/* Fill lifteme_cfg */
   1136	req.info.lft.soft_byte_limit	= XFRM_INF;
   1137	req.info.lft.hard_byte_limit	= XFRM_INF;
   1138	req.info.lft.soft_packet_limit	= XFRM_INF;
   1139	req.info.lft.hard_packet_limit	= XFRM_INF;
   1140
   1141	req.info.dir = dir;
   1142
   1143	/* Fill tmpl */
   1144	memcpy(&tmpl.id.daddr, &dst, sizeof(dst));
   1145	/* Note: zero-spi cannot be deleted */
   1146	tmpl.id.spi = spi;
   1147	tmpl.id.proto	= proto;
   1148	tmpl.family	= AF_INET;
   1149	memcpy(&tmpl.saddr, &src, sizeof(src));
   1150	tmpl.mode	= XFRM_MODE_TUNNEL;
   1151	tmpl.aalgos = (~(uint32_t)0);
   1152	tmpl.ealgos = (~(uint32_t)0);
   1153	tmpl.calgos = (~(uint32_t)0);
   1154
   1155	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_TMPL, &tmpl, sizeof(tmpl)))
   1156		return -1;
   1157
   1158	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
   1159		pr_err("send()");
   1160		return -1;
   1161	}
   1162
   1163	return netlink_check_answer(xfrm_sock);
   1164}
   1165
   1166static int xfrm_prepare(int xfrm_sock, uint32_t *seq,
   1167		struct in_addr src, struct in_addr dst,
   1168		struct in_addr tunsrc, struct in_addr tundst, uint8_t proto)
   1169{
   1170	if (xfrm_policy_add(xfrm_sock, (*seq)++, gen_spi(src), src, dst,
   1171				XFRM_POLICY_OUT, tunsrc, tundst, proto)) {
   1172		printk("Failed to add xfrm policy");
   1173		return -1;
   1174	}
   1175
   1176	if (xfrm_policy_add(xfrm_sock, (*seq)++, gen_spi(src), dst, src,
   1177				XFRM_POLICY_IN, tunsrc, tundst, proto)) {
   1178		printk("Failed to add xfrm policy");
   1179		return -1;
   1180	}
   1181
   1182	return 0;
   1183}
   1184
   1185static int xfrm_policy_del(int xfrm_sock, uint32_t seq,
   1186		struct in_addr src, struct in_addr dst, uint8_t dir,
   1187		struct in_addr tunsrc, struct in_addr tundst)
   1188{
   1189	struct {
   1190		struct nlmsghdr			nh;
   1191		struct xfrm_userpolicy_id	id;
   1192		char				attrbuf[MAX_PAYLOAD];
   1193	} req;
   1194
   1195	memset(&req, 0, sizeof(req));
   1196	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.id));
   1197	req.nh.nlmsg_type	= XFRM_MSG_DELPOLICY;
   1198	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
   1199	req.nh.nlmsg_seq	= seq;
   1200
   1201	/* Fill id */
   1202	memcpy(&req.id.sel.daddr, &dst, sizeof(tundst));
   1203	memcpy(&req.id.sel.saddr, &src, sizeof(tunsrc));
   1204	req.id.sel.family		= AF_INET;
   1205	req.id.sel.prefixlen_d		= PREFIX_LEN;
   1206	req.id.sel.prefixlen_s		= PREFIX_LEN;
   1207	req.id.dir = dir;
   1208
   1209	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
   1210		pr_err("send()");
   1211		return -1;
   1212	}
   1213
   1214	return netlink_check_answer(xfrm_sock);
   1215}
   1216
   1217static int xfrm_cleanup(int xfrm_sock, uint32_t *seq,
   1218		struct in_addr src, struct in_addr dst,
   1219		struct in_addr tunsrc, struct in_addr tundst)
   1220{
   1221	if (xfrm_policy_del(xfrm_sock, (*seq)++, src, dst,
   1222				XFRM_POLICY_OUT, tunsrc, tundst)) {
   1223		printk("Failed to add xfrm policy");
   1224		return -1;
   1225	}
   1226
   1227	if (xfrm_policy_del(xfrm_sock, (*seq)++, dst, src,
   1228				XFRM_POLICY_IN, tunsrc, tundst)) {
   1229		printk("Failed to add xfrm policy");
   1230		return -1;
   1231	}
   1232
   1233	return 0;
   1234}
   1235
   1236static int xfrm_state_del(int xfrm_sock, uint32_t seq, uint32_t spi,
   1237		struct in_addr src, struct in_addr dst, uint8_t proto)
   1238{
   1239	struct {
   1240		struct nlmsghdr		nh;
   1241		struct xfrm_usersa_id	id;
   1242		char			attrbuf[MAX_PAYLOAD];
   1243	} req;
   1244	xfrm_address_t saddr = {};
   1245
   1246	memset(&req, 0, sizeof(req));
   1247	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.id));
   1248	req.nh.nlmsg_type	= XFRM_MSG_DELSA;
   1249	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
   1250	req.nh.nlmsg_seq	= seq;
   1251
   1252	memcpy(&req.id.daddr, &dst, sizeof(dst));
   1253	req.id.family		= AF_INET;
   1254	req.id.proto		= proto;
   1255	/* Note: zero-spi cannot be deleted */
   1256	req.id.spi = spi;
   1257
   1258	memcpy(&saddr, &src, sizeof(src));
   1259	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_SRCADDR, &saddr, sizeof(saddr)))
   1260		return -1;
   1261
   1262	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
   1263		pr_err("send()");
   1264		return -1;
   1265	}
   1266
   1267	return netlink_check_answer(xfrm_sock);
   1268}
   1269
   1270static int xfrm_delete(int xfrm_sock, uint32_t *seq,
   1271		struct in_addr src, struct in_addr dst,
   1272		struct in_addr tunsrc, struct in_addr tundst, uint8_t proto)
   1273{
   1274	if (xfrm_state_del(xfrm_sock, (*seq)++, gen_spi(src), src, dst, proto)) {
   1275		printk("Failed to remove xfrm state");
   1276		return -1;
   1277	}
   1278
   1279	if (xfrm_state_del(xfrm_sock, (*seq)++, gen_spi(src), dst, src, proto)) {
   1280		printk("Failed to remove xfrm state");
   1281		return -1;
   1282	}
   1283
   1284	return 0;
   1285}
   1286
   1287static int xfrm_state_allocspi(int xfrm_sock, uint32_t *seq,
   1288		uint32_t spi, uint8_t proto)
   1289{
   1290	struct {
   1291		struct nlmsghdr			nh;
   1292		struct xfrm_userspi_info	spi;
   1293	} req;
   1294	struct {
   1295		struct nlmsghdr			nh;
   1296		union {
   1297			struct xfrm_usersa_info	info;
   1298			int error;
   1299		};
   1300	} answer;
   1301
   1302	memset(&req, 0, sizeof(req));
   1303	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.spi));
   1304	req.nh.nlmsg_type	= XFRM_MSG_ALLOCSPI;
   1305	req.nh.nlmsg_flags	= NLM_F_REQUEST;
   1306	req.nh.nlmsg_seq	= (*seq)++;
   1307
   1308	req.spi.info.family	= AF_INET;
   1309	req.spi.min		= spi;
   1310	req.spi.max		= spi;
   1311	req.spi.info.id.proto	= proto;
   1312
   1313	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
   1314		pr_err("send()");
   1315		return KSFT_FAIL;
   1316	}
   1317
   1318	if (recv(xfrm_sock, &answer, sizeof(answer), 0) < 0) {
   1319		pr_err("recv()");
   1320		return KSFT_FAIL;
   1321	} else if (answer.nh.nlmsg_type == XFRM_MSG_NEWSA) {
   1322		uint32_t new_spi = htonl(answer.info.id.spi);
   1323
   1324		if (new_spi != spi) {
   1325			printk("allocated spi is different from requested: %#x != %#x",
   1326					new_spi, spi);
   1327			return KSFT_FAIL;
   1328		}
   1329		return KSFT_PASS;
   1330	} else if (answer.nh.nlmsg_type != NLMSG_ERROR) {
   1331		printk("expected NLMSG_ERROR, got %d", (int)answer.nh.nlmsg_type);
   1332		return KSFT_FAIL;
   1333	}
   1334
   1335	printk("NLMSG_ERROR: %d: %s", answer.error, strerror(-answer.error));
   1336	return (answer.error) ? KSFT_FAIL : KSFT_PASS;
   1337}
   1338
   1339static int netlink_sock_bind(int *sock, uint32_t *seq, int proto, uint32_t groups)
   1340{
   1341	struct sockaddr_nl snl = {};
   1342	socklen_t addr_len;
   1343	int ret = -1;
   1344
   1345	snl.nl_family = AF_NETLINK;
   1346	snl.nl_groups = groups;
   1347
   1348	if (netlink_sock(sock, seq, proto)) {
   1349		printk("Failed to open xfrm netlink socket");
   1350		return -1;
   1351	}
   1352
   1353	if (bind(*sock, (struct sockaddr *)&snl, sizeof(snl)) < 0) {
   1354		pr_err("bind()");
   1355		goto out_close;
   1356	}
   1357
   1358	addr_len = sizeof(snl);
   1359	if (getsockname(*sock, (struct sockaddr *)&snl, &addr_len) < 0) {
   1360		pr_err("getsockname()");
   1361		goto out_close;
   1362	}
   1363	if (addr_len != sizeof(snl)) {
   1364		printk("Wrong address length %d", addr_len);
   1365		goto out_close;
   1366	}
   1367	if (snl.nl_family != AF_NETLINK) {
   1368		printk("Wrong address family %d", snl.nl_family);
   1369		goto out_close;
   1370	}
   1371	return 0;
   1372
   1373out_close:
   1374	close(*sock);
   1375	return ret;
   1376}
   1377
   1378static int xfrm_monitor_acquire(int xfrm_sock, uint32_t *seq, unsigned int nr)
   1379{
   1380	struct {
   1381		struct nlmsghdr nh;
   1382		union {
   1383			struct xfrm_user_acquire acq;
   1384			int error;
   1385		};
   1386		char attrbuf[MAX_PAYLOAD];
   1387	} req;
   1388	struct xfrm_user_tmpl xfrm_tmpl = {};
   1389	int xfrm_listen = -1, ret = KSFT_FAIL;
   1390	uint32_t seq_listen;
   1391
   1392	if (netlink_sock_bind(&xfrm_listen, &seq_listen, NETLINK_XFRM, XFRMNLGRP_ACQUIRE))
   1393		return KSFT_FAIL;
   1394
   1395	memset(&req, 0, sizeof(req));
   1396	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.acq));
   1397	req.nh.nlmsg_type	= XFRM_MSG_ACQUIRE;
   1398	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
   1399	req.nh.nlmsg_seq	= (*seq)++;
   1400
   1401	req.acq.policy.sel.family	= AF_INET;
   1402	req.acq.aalgos	= 0xfeed;
   1403	req.acq.ealgos	= 0xbaad;
   1404	req.acq.calgos	= 0xbabe;
   1405
   1406	xfrm_tmpl.family = AF_INET;
   1407	xfrm_tmpl.id.proto = IPPROTO_ESP;
   1408	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_TMPL, &xfrm_tmpl, sizeof(xfrm_tmpl)))
   1409		goto out_close;
   1410
   1411	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
   1412		pr_err("send()");
   1413		goto out_close;
   1414	}
   1415
   1416	if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
   1417		pr_err("recv()");
   1418		goto out_close;
   1419	} else if (req.nh.nlmsg_type != NLMSG_ERROR) {
   1420		printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
   1421		goto out_close;
   1422	}
   1423
   1424	if (req.error) {
   1425		printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
   1426		ret = req.error;
   1427		goto out_close;
   1428	}
   1429
   1430	if (recv(xfrm_listen, &req, sizeof(req), 0) < 0) {
   1431		pr_err("recv()");
   1432		goto out_close;
   1433	}
   1434
   1435	if (req.acq.aalgos != 0xfeed || req.acq.ealgos != 0xbaad
   1436			|| req.acq.calgos != 0xbabe) {
   1437		printk("xfrm_user_acquire has changed  %x %x %x",
   1438				req.acq.aalgos, req.acq.ealgos, req.acq.calgos);
   1439		goto out_close;
   1440	}
   1441
   1442	ret = KSFT_PASS;
   1443out_close:
   1444	close(xfrm_listen);
   1445	return ret;
   1446}
   1447
   1448static int xfrm_expire_state(int xfrm_sock, uint32_t *seq,
   1449		unsigned int nr, struct xfrm_desc *desc)
   1450{
   1451	struct {
   1452		struct nlmsghdr nh;
   1453		union {
   1454			struct xfrm_user_expire expire;
   1455			int error;
   1456		};
   1457	} req;
   1458	struct in_addr src, dst;
   1459	int xfrm_listen = -1, ret = KSFT_FAIL;
   1460	uint32_t seq_listen;
   1461
   1462	src = inet_makeaddr(INADDR_B, child_ip(nr));
   1463	dst = inet_makeaddr(INADDR_B, grchild_ip(nr));
   1464
   1465	if (xfrm_state_add(xfrm_sock, (*seq)++, gen_spi(src), src, dst, desc)) {
   1466		printk("Failed to add xfrm state");
   1467		return KSFT_FAIL;
   1468	}
   1469
   1470	if (netlink_sock_bind(&xfrm_listen, &seq_listen, NETLINK_XFRM, XFRMNLGRP_EXPIRE))
   1471		return KSFT_FAIL;
   1472
   1473	memset(&req, 0, sizeof(req));
   1474	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.expire));
   1475	req.nh.nlmsg_type	= XFRM_MSG_EXPIRE;
   1476	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
   1477	req.nh.nlmsg_seq	= (*seq)++;
   1478
   1479	memcpy(&req.expire.state.id.daddr, &dst, sizeof(dst));
   1480	req.expire.state.id.spi		= gen_spi(src);
   1481	req.expire.state.id.proto	= desc->proto;
   1482	req.expire.state.family		= AF_INET;
   1483	req.expire.hard			= 0xff;
   1484
   1485	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
   1486		pr_err("send()");
   1487		goto out_close;
   1488	}
   1489
   1490	if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
   1491		pr_err("recv()");
   1492		goto out_close;
   1493	} else if (req.nh.nlmsg_type != NLMSG_ERROR) {
   1494		printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
   1495		goto out_close;
   1496	}
   1497
   1498	if (req.error) {
   1499		printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
   1500		ret = req.error;
   1501		goto out_close;
   1502	}
   1503
   1504	if (recv(xfrm_listen, &req, sizeof(req), 0) < 0) {
   1505		pr_err("recv()");
   1506		goto out_close;
   1507	}
   1508
   1509	if (req.expire.hard != 0x1) {
   1510		printk("expire.hard is not set: %x", req.expire.hard);
   1511		goto out_close;
   1512	}
   1513
   1514	ret = KSFT_PASS;
   1515out_close:
   1516	close(xfrm_listen);
   1517	return ret;
   1518}
   1519
   1520static int xfrm_expire_policy(int xfrm_sock, uint32_t *seq,
   1521		unsigned int nr, struct xfrm_desc *desc)
   1522{
   1523	struct {
   1524		struct nlmsghdr nh;
   1525		union {
   1526			struct xfrm_user_polexpire expire;
   1527			int error;
   1528		};
   1529	} req;
   1530	struct in_addr src, dst, tunsrc, tundst;
   1531	int xfrm_listen = -1, ret = KSFT_FAIL;
   1532	uint32_t seq_listen;
   1533
   1534	src = inet_makeaddr(INADDR_B, child_ip(nr));
   1535	dst = inet_makeaddr(INADDR_B, grchild_ip(nr));
   1536	tunsrc = inet_makeaddr(INADDR_A, child_ip(nr));
   1537	tundst = inet_makeaddr(INADDR_A, grchild_ip(nr));
   1538
   1539	if (xfrm_policy_add(xfrm_sock, (*seq)++, gen_spi(src), src, dst,
   1540				XFRM_POLICY_OUT, tunsrc, tundst, desc->proto)) {
   1541		printk("Failed to add xfrm policy");
   1542		return KSFT_FAIL;
   1543	}
   1544
   1545	if (netlink_sock_bind(&xfrm_listen, &seq_listen, NETLINK_XFRM, XFRMNLGRP_EXPIRE))
   1546		return KSFT_FAIL;
   1547
   1548	memset(&req, 0, sizeof(req));
   1549	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.expire));
   1550	req.nh.nlmsg_type	= XFRM_MSG_POLEXPIRE;
   1551	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
   1552	req.nh.nlmsg_seq	= (*seq)++;
   1553
   1554	/* Fill selector. */
   1555	memcpy(&req.expire.pol.sel.daddr, &dst, sizeof(tundst));
   1556	memcpy(&req.expire.pol.sel.saddr, &src, sizeof(tunsrc));
   1557	req.expire.pol.sel.family	= AF_INET;
   1558	req.expire.pol.sel.prefixlen_d	= PREFIX_LEN;
   1559	req.expire.pol.sel.prefixlen_s	= PREFIX_LEN;
   1560	req.expire.pol.dir		= XFRM_POLICY_OUT;
   1561	req.expire.hard			= 0xff;
   1562
   1563	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
   1564		pr_err("send()");
   1565		goto out_close;
   1566	}
   1567
   1568	if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
   1569		pr_err("recv()");
   1570		goto out_close;
   1571	} else if (req.nh.nlmsg_type != NLMSG_ERROR) {
   1572		printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
   1573		goto out_close;
   1574	}
   1575
   1576	if (req.error) {
   1577		printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
   1578		ret = req.error;
   1579		goto out_close;
   1580	}
   1581
   1582	if (recv(xfrm_listen, &req, sizeof(req), 0) < 0) {
   1583		pr_err("recv()");
   1584		goto out_close;
   1585	}
   1586
   1587	if (req.expire.hard != 0x1) {
   1588		printk("expire.hard is not set: %x", req.expire.hard);
   1589		goto out_close;
   1590	}
   1591
   1592	ret = KSFT_PASS;
   1593out_close:
   1594	close(xfrm_listen);
   1595	return ret;
   1596}
   1597
   1598static int xfrm_spdinfo_set_thresh(int xfrm_sock, uint32_t *seq,
   1599		unsigned thresh4_l, unsigned thresh4_r,
   1600		unsigned thresh6_l, unsigned thresh6_r,
   1601		bool add_bad_attr)
   1602
   1603{
   1604	struct {
   1605		struct nlmsghdr		nh;
   1606		union {
   1607			uint32_t	unused;
   1608			int		error;
   1609		};
   1610		char			attrbuf[MAX_PAYLOAD];
   1611	} req;
   1612	struct xfrmu_spdhthresh thresh;
   1613
   1614	memset(&req, 0, sizeof(req));
   1615	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.unused));
   1616	req.nh.nlmsg_type	= XFRM_MSG_NEWSPDINFO;
   1617	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
   1618	req.nh.nlmsg_seq	= (*seq)++;
   1619
   1620	thresh.lbits = thresh4_l;
   1621	thresh.rbits = thresh4_r;
   1622	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_SPD_IPV4_HTHRESH, &thresh, sizeof(thresh)))
   1623		return -1;
   1624
   1625	thresh.lbits = thresh6_l;
   1626	thresh.rbits = thresh6_r;
   1627	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_SPD_IPV6_HTHRESH, &thresh, sizeof(thresh)))
   1628		return -1;
   1629
   1630	if (add_bad_attr) {
   1631		BUILD_BUG_ON(XFRMA_IF_ID <= XFRMA_SPD_MAX + 1);
   1632		if (rtattr_pack(&req.nh, sizeof(req), XFRMA_IF_ID, NULL, 0)) {
   1633			pr_err("adding attribute failed: no space");
   1634			return -1;
   1635		}
   1636	}
   1637
   1638	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
   1639		pr_err("send()");
   1640		return -1;
   1641	}
   1642
   1643	if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
   1644		pr_err("recv()");
   1645		return -1;
   1646	} else if (req.nh.nlmsg_type != NLMSG_ERROR) {
   1647		printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
   1648		return -1;
   1649	}
   1650
   1651	if (req.error) {
   1652		printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
   1653		return -1;
   1654	}
   1655
   1656	return 0;
   1657}
   1658
   1659static int xfrm_spdinfo_attrs(int xfrm_sock, uint32_t *seq)
   1660{
   1661	struct {
   1662		struct nlmsghdr			nh;
   1663		union {
   1664			uint32_t	unused;
   1665			int		error;
   1666		};
   1667		char			attrbuf[MAX_PAYLOAD];
   1668	} req;
   1669
   1670	if (xfrm_spdinfo_set_thresh(xfrm_sock, seq, 32, 31, 120, 16, false)) {
   1671		pr_err("Can't set SPD HTHRESH");
   1672		return KSFT_FAIL;
   1673	}
   1674
   1675	memset(&req, 0, sizeof(req));
   1676
   1677	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.unused));
   1678	req.nh.nlmsg_type	= XFRM_MSG_GETSPDINFO;
   1679	req.nh.nlmsg_flags	= NLM_F_REQUEST;
   1680	req.nh.nlmsg_seq	= (*seq)++;
   1681	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
   1682		pr_err("send()");
   1683		return KSFT_FAIL;
   1684	}
   1685
   1686	if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
   1687		pr_err("recv()");
   1688		return KSFT_FAIL;
   1689	} else if (req.nh.nlmsg_type == XFRM_MSG_NEWSPDINFO) {
   1690		size_t len = NLMSG_PAYLOAD(&req.nh, sizeof(req.unused));
   1691		struct rtattr *attr = (void *)req.attrbuf;
   1692		int got_thresh = 0;
   1693
   1694		for (; RTA_OK(attr, len); attr = RTA_NEXT(attr, len)) {
   1695			if (attr->rta_type == XFRMA_SPD_IPV4_HTHRESH) {
   1696				struct xfrmu_spdhthresh *t = RTA_DATA(attr);
   1697
   1698				got_thresh++;
   1699				if (t->lbits != 32 || t->rbits != 31) {
   1700					pr_err("thresh differ: %u, %u",
   1701							t->lbits, t->rbits);
   1702					return KSFT_FAIL;
   1703				}
   1704			}
   1705			if (attr->rta_type == XFRMA_SPD_IPV6_HTHRESH) {
   1706				struct xfrmu_spdhthresh *t = RTA_DATA(attr);
   1707
   1708				got_thresh++;
   1709				if (t->lbits != 120 || t->rbits != 16) {
   1710					pr_err("thresh differ: %u, %u",
   1711							t->lbits, t->rbits);
   1712					return KSFT_FAIL;
   1713				}
   1714			}
   1715		}
   1716		if (got_thresh != 2) {
   1717			pr_err("only %d thresh returned by XFRM_MSG_GETSPDINFO", got_thresh);
   1718			return KSFT_FAIL;
   1719		}
   1720	} else if (req.nh.nlmsg_type != NLMSG_ERROR) {
   1721		printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
   1722		return KSFT_FAIL;
   1723	} else {
   1724		printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
   1725		return -1;
   1726	}
   1727
   1728	/* Restore the default */
   1729	if (xfrm_spdinfo_set_thresh(xfrm_sock, seq, 32, 32, 128, 128, false)) {
   1730		pr_err("Can't restore SPD HTHRESH");
   1731		return KSFT_FAIL;
   1732	}
   1733
   1734	/*
   1735	 * At this moment xfrm uses nlmsg_parse_deprecated(), which
   1736	 * implies NL_VALIDATE_LIBERAL - ignoring attributes with
   1737	 * (type > maxtype). nla_parse_depricated_strict() would enforce
   1738	 * it. Or even stricter nla_parse().
   1739	 * Right now it's not expected to fail, but to be ignored.
   1740	 */
   1741	if (xfrm_spdinfo_set_thresh(xfrm_sock, seq, 32, 32, 128, 128, true))
   1742		return KSFT_PASS;
   1743
   1744	return KSFT_PASS;
   1745}
   1746
   1747static int child_serv(int xfrm_sock, uint32_t *seq,
   1748		unsigned int nr, int cmd_fd, void *buf, struct xfrm_desc *desc)
   1749{
   1750	struct in_addr src, dst, tunsrc, tundst;
   1751	struct test_desc msg;
   1752	int ret = KSFT_FAIL;
   1753
   1754	src = inet_makeaddr(INADDR_B, child_ip(nr));
   1755	dst = inet_makeaddr(INADDR_B, grchild_ip(nr));
   1756	tunsrc = inet_makeaddr(INADDR_A, child_ip(nr));
   1757	tundst = inet_makeaddr(INADDR_A, grchild_ip(nr));
   1758
   1759	/* UDP pinging without xfrm */
   1760	if (do_ping(cmd_fd, buf, page_size, src, true, 0, 0, udp_ping_send)) {
   1761		printk("ping failed before setting xfrm");
   1762		return KSFT_FAIL;
   1763	}
   1764
   1765	memset(&msg, 0, sizeof(msg));
   1766	msg.type = MSG_XFRM_PREPARE;
   1767	memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
   1768	write_msg(cmd_fd, &msg, 1);
   1769
   1770	if (xfrm_prepare(xfrm_sock, seq, src, dst, tunsrc, tundst, desc->proto)) {
   1771		printk("failed to prepare xfrm");
   1772		goto cleanup;
   1773	}
   1774
   1775	memset(&msg, 0, sizeof(msg));
   1776	msg.type = MSG_XFRM_ADD;
   1777	memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
   1778	write_msg(cmd_fd, &msg, 1);
   1779	if (xfrm_set(xfrm_sock, seq, src, dst, tunsrc, tundst, desc)) {
   1780		printk("failed to set xfrm");
   1781		goto delete;
   1782	}
   1783
   1784	/* UDP pinging with xfrm tunnel */
   1785	if (do_ping(cmd_fd, buf, page_size, tunsrc,
   1786				true, 0, 0, udp_ping_send)) {
   1787		printk("ping failed for xfrm");
   1788		goto delete;
   1789	}
   1790
   1791	ret = KSFT_PASS;
   1792delete:
   1793	/* xfrm delete */
   1794	memset(&msg, 0, sizeof(msg));
   1795	msg.type = MSG_XFRM_DEL;
   1796	memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
   1797	write_msg(cmd_fd, &msg, 1);
   1798
   1799	if (xfrm_delete(xfrm_sock, seq, src, dst, tunsrc, tundst, desc->proto)) {
   1800		printk("failed ping to remove xfrm");
   1801		ret = KSFT_FAIL;
   1802	}
   1803
   1804cleanup:
   1805	memset(&msg, 0, sizeof(msg));
   1806	msg.type = MSG_XFRM_CLEANUP;
   1807	memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
   1808	write_msg(cmd_fd, &msg, 1);
   1809	if (xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst)) {
   1810		printk("failed ping to cleanup xfrm");
   1811		ret = KSFT_FAIL;
   1812	}
   1813	return ret;
   1814}
   1815
   1816static int child_f(unsigned int nr, int test_desc_fd, int cmd_fd, void *buf)
   1817{
   1818	struct xfrm_desc desc;
   1819	struct test_desc msg;
   1820	int xfrm_sock = -1;
   1821	uint32_t seq;
   1822
   1823	if (switch_ns(nsfd_childa))
   1824		exit(KSFT_FAIL);
   1825
   1826	if (netlink_sock(&xfrm_sock, &seq, NETLINK_XFRM)) {
   1827		printk("Failed to open xfrm netlink socket");
   1828		exit(KSFT_FAIL);
   1829	}
   1830
   1831	/* Check that seq sock is ready, just for sure. */
   1832	memset(&msg, 0, sizeof(msg));
   1833	msg.type = MSG_ACK;
   1834	write_msg(cmd_fd, &msg, 1);
   1835	read_msg(cmd_fd, &msg, 1);
   1836	if (msg.type != MSG_ACK) {
   1837		printk("Ack failed");
   1838		exit(KSFT_FAIL);
   1839	}
   1840
   1841	for (;;) {
   1842		ssize_t received = read(test_desc_fd, &desc, sizeof(desc));
   1843		int ret;
   1844
   1845		if (received == 0) /* EOF */
   1846			break;
   1847
   1848		if (received != sizeof(desc)) {
   1849			pr_err("read() returned %zd", received);
   1850			exit(KSFT_FAIL);
   1851		}
   1852
   1853		switch (desc.type) {
   1854		case CREATE_TUNNEL:
   1855			ret = child_serv(xfrm_sock, &seq, nr,
   1856					 cmd_fd, buf, &desc);
   1857			break;
   1858		case ALLOCATE_SPI:
   1859			ret = xfrm_state_allocspi(xfrm_sock, &seq,
   1860						  -1, desc.proto);
   1861			break;
   1862		case MONITOR_ACQUIRE:
   1863			ret = xfrm_monitor_acquire(xfrm_sock, &seq, nr);
   1864			break;
   1865		case EXPIRE_STATE:
   1866			ret = xfrm_expire_state(xfrm_sock, &seq, nr, &desc);
   1867			break;
   1868		case EXPIRE_POLICY:
   1869			ret = xfrm_expire_policy(xfrm_sock, &seq, nr, &desc);
   1870			break;
   1871		case SPDINFO_ATTRS:
   1872			ret = xfrm_spdinfo_attrs(xfrm_sock, &seq);
   1873			break;
   1874		default:
   1875			printk("Unknown desc type %d", desc.type);
   1876			exit(KSFT_FAIL);
   1877		}
   1878		write_test_result(ret, &desc);
   1879	}
   1880
   1881	close(xfrm_sock);
   1882
   1883	msg.type = MSG_EXIT;
   1884	write_msg(cmd_fd, &msg, 1);
   1885	exit(KSFT_PASS);
   1886}
   1887
   1888static void grand_child_serv(unsigned int nr, int cmd_fd, void *buf,
   1889		struct test_desc *msg, int xfrm_sock, uint32_t *seq)
   1890{
   1891	struct in_addr src, dst, tunsrc, tundst;
   1892	bool tun_reply;
   1893	struct xfrm_desc *desc = &msg->body.xfrm_desc;
   1894
   1895	src = inet_makeaddr(INADDR_B, grchild_ip(nr));
   1896	dst = inet_makeaddr(INADDR_B, child_ip(nr));
   1897	tunsrc = inet_makeaddr(INADDR_A, grchild_ip(nr));
   1898	tundst = inet_makeaddr(INADDR_A, child_ip(nr));
   1899
   1900	switch (msg->type) {
   1901	case MSG_EXIT:
   1902		exit(KSFT_PASS);
   1903	case MSG_ACK:
   1904		write_msg(cmd_fd, msg, 1);
   1905		break;
   1906	case MSG_PING:
   1907		tun_reply = memcmp(&dst, &msg->body.ping.reply_ip, sizeof(in_addr_t));
   1908		/* UDP pinging without xfrm */
   1909		if (do_ping(cmd_fd, buf, page_size, tun_reply ? tunsrc : src,
   1910				false, msg->body.ping.port,
   1911				msg->body.ping.reply_ip, udp_ping_reply)) {
   1912			printk("ping failed before setting xfrm");
   1913		}
   1914		break;
   1915	case MSG_XFRM_PREPARE:
   1916		if (xfrm_prepare(xfrm_sock, seq, src, dst, tunsrc, tundst,
   1917					desc->proto)) {
   1918			xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst);
   1919			printk("failed to prepare xfrm");
   1920		}
   1921		break;
   1922	case MSG_XFRM_ADD:
   1923		if (xfrm_set(xfrm_sock, seq, src, dst, tunsrc, tundst, desc)) {
   1924			xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst);
   1925			printk("failed to set xfrm");
   1926		}
   1927		break;
   1928	case MSG_XFRM_DEL:
   1929		if (xfrm_delete(xfrm_sock, seq, src, dst, tunsrc, tundst,
   1930					desc->proto)) {
   1931			xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst);
   1932			printk("failed to remove xfrm");
   1933		}
   1934		break;
   1935	case MSG_XFRM_CLEANUP:
   1936		if (xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst)) {
   1937			printk("failed to cleanup xfrm");
   1938		}
   1939		break;
   1940	default:
   1941		printk("got unknown msg type %d", msg->type);
   1942	}
   1943}
   1944
   1945static int grand_child_f(unsigned int nr, int cmd_fd, void *buf)
   1946{
   1947	struct test_desc msg;
   1948	int xfrm_sock = -1;
   1949	uint32_t seq;
   1950
   1951	if (switch_ns(nsfd_childb))
   1952		exit(KSFT_FAIL);
   1953
   1954	if (netlink_sock(&xfrm_sock, &seq, NETLINK_XFRM)) {
   1955		printk("Failed to open xfrm netlink socket");
   1956		exit(KSFT_FAIL);
   1957	}
   1958
   1959	do {
   1960		read_msg(cmd_fd, &msg, 1);
   1961		grand_child_serv(nr, cmd_fd, buf, &msg, xfrm_sock, &seq);
   1962	} while (1);
   1963
   1964	close(xfrm_sock);
   1965	exit(KSFT_FAIL);
   1966}
   1967
   1968static int start_child(unsigned int nr, char *veth, int test_desc_fd[2])
   1969{
   1970	int cmd_sock[2];
   1971	void *data_map;
   1972	pid_t child;
   1973
   1974	if (init_child(nsfd_childa, veth, child_ip(nr), grchild_ip(nr)))
   1975		return -1;
   1976
   1977	if (init_child(nsfd_childb, veth, grchild_ip(nr), child_ip(nr)))
   1978		return -1;
   1979
   1980	child = fork();
   1981	if (child < 0) {
   1982		pr_err("fork()");
   1983		return -1;
   1984	} else if (child) {
   1985		/* in parent - selftest */
   1986		return switch_ns(nsfd_parent);
   1987	}
   1988
   1989	if (close(test_desc_fd[1])) {
   1990		pr_err("close()");
   1991		return -1;
   1992	}
   1993
   1994	/* child */
   1995	data_map = mmap(0, page_size, PROT_READ | PROT_WRITE,
   1996			MAP_SHARED | MAP_ANONYMOUS, -1, 0);
   1997	if (data_map == MAP_FAILED) {
   1998		pr_err("mmap()");
   1999		return -1;
   2000	}
   2001
   2002	randomize_buffer(data_map, page_size);
   2003
   2004	if (socketpair(PF_LOCAL, SOCK_SEQPACKET, 0, cmd_sock)) {
   2005		pr_err("socketpair()");
   2006		return -1;
   2007	}
   2008
   2009	child = fork();
   2010	if (child < 0) {
   2011		pr_err("fork()");
   2012		return -1;
   2013	} else if (child) {
   2014		if (close(cmd_sock[0])) {
   2015			pr_err("close()");
   2016			return -1;
   2017		}
   2018		return child_f(nr, test_desc_fd[0], cmd_sock[1], data_map);
   2019	}
   2020	if (close(cmd_sock[1])) {
   2021		pr_err("close()");
   2022		return -1;
   2023	}
   2024	return grand_child_f(nr, cmd_sock[0], data_map);
   2025}
   2026
   2027static void exit_usage(char **argv)
   2028{
   2029	printk("Usage: %s [nr_process]", argv[0]);
   2030	exit(KSFT_FAIL);
   2031}
   2032
   2033static int __write_desc(int test_desc_fd, struct xfrm_desc *desc)
   2034{
   2035	ssize_t ret;
   2036
   2037	ret = write(test_desc_fd, desc, sizeof(*desc));
   2038
   2039	if (ret == sizeof(*desc))
   2040		return 0;
   2041
   2042	pr_err("Writing test's desc failed %ld", ret);
   2043
   2044	return -1;
   2045}
   2046
   2047static int write_desc(int proto, int test_desc_fd,
   2048		char *a, char *e, char *c, char *ae)
   2049{
   2050	struct xfrm_desc desc = {};
   2051
   2052	desc.type = CREATE_TUNNEL;
   2053	desc.proto = proto;
   2054
   2055	if (a)
   2056		strncpy(desc.a_algo, a, ALGO_LEN - 1);
   2057	if (e)
   2058		strncpy(desc.e_algo, e, ALGO_LEN - 1);
   2059	if (c)
   2060		strncpy(desc.c_algo, c, ALGO_LEN - 1);
   2061	if (ae)
   2062		strncpy(desc.ae_algo, ae, ALGO_LEN - 1);
   2063
   2064	return __write_desc(test_desc_fd, &desc);
   2065}
   2066
/*
 * Protocols covered by the test plan: write_test_plan() passes each entry
 * to write_proto_plan().
 */
int proto_list[] = { IPPROTO_AH, IPPROTO_COMP, IPPROTO_ESP };
/*
 * Authentication algorithm names. write_proto_plan() uses them alone for
 * IPPROTO_AH and paired with every e_list entry for IPPROTO_ESP.
 */
char *ah_list[] = {
	"digest_null", "hmac(md5)", "hmac(sha1)", "hmac(sha256)",
	"hmac(sha384)", "hmac(sha512)", "hmac(rmd160)",
	"xcbc(aes)", "cmac(aes)"
};
/* Compression algorithm names used by write_proto_plan() for IPPROTO_COMP. */
char *comp_list[] = {
	"deflate",
#if 0
	/* No compression backend realization */
	"lzs", "lzjh"
#endif
};
/*
 * Encryption algorithm names; write_proto_plan() combines each of them
 * with every ah_list entry for IPPROTO_ESP.
 */
char *e_list[] = {
	"ecb(cipher_null)", "cbc(des)", "cbc(des3_ede)", "cbc(cast5)",
	"cbc(blowfish)", "cbc(aes)", "cbc(serpent)", "cbc(camellia)",
	"cbc(twofish)", "rfc3686(ctr(aes))"
};
/*
 * Names passed as the "ae" argument of write_desc() for IPPROTO_ESP.
 * Every entry is currently compiled out, so the list is empty.
 */
char *ae_list[] = {
#if 0
	/* not implemented */
	"rfc4106(gcm(aes))", "rfc4309(ccm(aes))", "rfc4543(gcm(aes))",
	"rfc7539esp(chacha20,poly1305)"
#endif
};
   2092
/*
 * Number of descriptors write_proto_plan() emits over all of proto_list:
 * one per ah_list entry (AH), one per comp_list entry (COMP), and for ESP
 * one per (ah_list, e_list) pair plus one per ae_list entry.
 * main() adds compat_plan to this value for ksft_set_plan().
 */
const unsigned int proto_plan = ARRAY_SIZE(ah_list) + ARRAY_SIZE(comp_list) \
				+ (ARRAY_SIZE(ah_list) * ARRAY_SIZE(e_list)) \
				+ ARRAY_SIZE(ae_list);
   2096
   2097static int write_proto_plan(int fd, int proto)
   2098{
   2099	unsigned int i;
   2100
   2101	switch (proto) {
   2102	case IPPROTO_AH:
   2103		for (i = 0; i < ARRAY_SIZE(ah_list); i++) {
   2104			if (write_desc(proto, fd, ah_list[i], 0, 0, 0))
   2105				return -1;
   2106		}
   2107		break;
   2108	case IPPROTO_COMP:
   2109		for (i = 0; i < ARRAY_SIZE(comp_list); i++) {
   2110			if (write_desc(proto, fd, 0, 0, comp_list[i], 0))
   2111				return -1;
   2112		}
   2113		break;
   2114	case IPPROTO_ESP:
   2115		for (i = 0; i < ARRAY_SIZE(ah_list); i++) {
   2116			int j;
   2117
   2118			for (j = 0; j < ARRAY_SIZE(e_list); j++) {
   2119				if (write_desc(proto, fd, ah_list[i],
   2120							e_list[j], 0, 0))
   2121					return -1;
   2122			}
   2123		}
   2124		for (i = 0; i < ARRAY_SIZE(ae_list); i++) {
   2125			if (write_desc(proto, fd, 0, 0, 0, ae_list[i]))
   2126				return -1;
   2127		}
   2128		break;
   2129	default:
   2130		printk("BUG: Specified unknown proto %d", proto);
   2131		return -1;
   2132	}
   2133
   2134	return 0;
   2135}
   2136
   2137/*
   2138 * Some structures in xfrm uapi header differ in size between
   2139 * 64-bit and 32-bit ABI:
   2140 *
   2141 *             32-bit UABI               |            64-bit UABI
   2142 *  -------------------------------------|-------------------------------------
   2143 *   sizeof(xfrm_usersa_info)     = 220  |  sizeof(xfrm_usersa_info)     = 224
   2144 *   sizeof(xfrm_userpolicy_info) = 164  |  sizeof(xfrm_userpolicy_info) = 168
   2145 *   sizeof(xfrm_userspi_info)    = 228  |  sizeof(xfrm_userspi_info)    = 232
   2146 *   sizeof(xfrm_user_acquire)    = 276  |  sizeof(xfrm_user_acquire)    = 280
   2147 *   sizeof(xfrm_user_expire)     = 224  |  sizeof(xfrm_user_expire)     = 232
   2148 *   sizeof(xfrm_user_polexpire)  = 168  |  sizeof(xfrm_user_polexpire)  = 176
   2149 *
   2150 * Check the affected by the UABI difference structures.
   2151 * Also, check translation for xfrm_set_spdinfo: it has it's own attributes
   2152 * which needs to be correctly copied, but not translated.
   2153 */
   2154const unsigned int compat_plan = 5;
   2155static int write_compat_struct_tests(int test_desc_fd)
   2156{
   2157	struct xfrm_desc desc = {};
   2158
   2159	desc.type = ALLOCATE_SPI;
   2160	desc.proto = IPPROTO_AH;
   2161	strncpy(desc.a_algo, ah_list[0], ALGO_LEN - 1);
   2162
   2163	if (__write_desc(test_desc_fd, &desc))
   2164		return -1;
   2165
   2166	desc.type = MONITOR_ACQUIRE;
   2167	if (__write_desc(test_desc_fd, &desc))
   2168		return -1;
   2169
   2170	desc.type = EXPIRE_STATE;
   2171	if (__write_desc(test_desc_fd, &desc))
   2172		return -1;
   2173
   2174	desc.type = EXPIRE_POLICY;
   2175	if (__write_desc(test_desc_fd, &desc))
   2176		return -1;
   2177
   2178	desc.type = SPDINFO_ATTRS;
   2179	if (__write_desc(test_desc_fd, &desc))
   2180		return -1;
   2181
   2182	return 0;
   2183}
   2184
   2185static int write_test_plan(int test_desc_fd)
   2186{
   2187	unsigned int i;
   2188	pid_t child;
   2189
   2190	child = fork();
   2191	if (child < 0) {
   2192		pr_err("fork()");
   2193		return -1;
   2194	}
   2195	if (child) {
   2196		if (close(test_desc_fd))
   2197			printk("close(): %m");
   2198		return 0;
   2199	}
   2200
   2201	if (write_compat_struct_tests(test_desc_fd))
   2202		exit(KSFT_FAIL);
   2203
   2204	for (i = 0; i < ARRAY_SIZE(proto_list); i++) {
   2205		if (write_proto_plan(test_desc_fd, proto_list[i]))
   2206			exit(KSFT_FAIL);
   2207	}
   2208
   2209	exit(KSFT_PASS);
   2210}
   2211
   2212static int children_cleanup(void)
   2213{
   2214	unsigned ret = KSFT_PASS;
   2215
   2216	while (1) {
   2217		int status;
   2218		pid_t p = wait(&status);
   2219
   2220		if ((p < 0) && errno == ECHILD)
   2221			break;
   2222
   2223		if (p < 0) {
   2224			pr_err("wait()");
   2225			return KSFT_FAIL;
   2226		}
   2227
   2228		if (!WIFEXITED(status)) {
   2229			ret = KSFT_FAIL;
   2230			continue;
   2231		}
   2232
   2233		if (WEXITSTATUS(status) == KSFT_FAIL)
   2234			ret = KSFT_FAIL;
   2235	}
   2236
   2237	return ret;
   2238}
   2239
   2240typedef void (*print_res)(const char *, ...);
   2241
   2242static int check_results(void)
   2243{
   2244	struct test_result tr = {};
   2245	struct xfrm_desc *d = &tr.desc;
   2246	int ret = KSFT_PASS;
   2247
   2248	while (1) {
   2249		ssize_t received = read(results_fd[0], &tr, sizeof(tr));
   2250		print_res result;
   2251
   2252		if (received == 0) /* EOF */
   2253			break;
   2254
   2255		if (received != sizeof(tr)) {
   2256			pr_err("read() returned %zd", received);
   2257			return KSFT_FAIL;
   2258		}
   2259
   2260		switch (tr.res) {
   2261		case KSFT_PASS:
   2262			result = ksft_test_result_pass;
   2263			break;
   2264		case KSFT_FAIL:
   2265		default:
   2266			result = ksft_test_result_fail;
   2267			ret = KSFT_FAIL;
   2268		}
   2269
   2270		result(" %s: [%u, '%s', '%s', '%s', '%s', %u]\n",
   2271		       desc_name[d->type], (unsigned int)d->proto, d->a_algo,
   2272		       d->e_algo, d->c_algo, d->ae_algo, d->icv_len);
   2273	}
   2274
   2275	return ret;
   2276}
   2277
   2278int main(int argc, char **argv)
   2279{
   2280	unsigned int nr_process = 1;
   2281	int route_sock = -1, ret = KSFT_SKIP;
   2282	int test_desc_fd[2];
   2283	uint32_t route_seq;
   2284	unsigned int i;
   2285
   2286	if (argc > 2)
   2287		exit_usage(argv);
   2288
   2289	if (argc > 1) {
   2290		char *endptr;
   2291
   2292		errno = 0;
   2293		nr_process = strtol(argv[1], &endptr, 10);
   2294		if ((errno == ERANGE && (nr_process == LONG_MAX || nr_process == LONG_MIN))
   2295				|| (errno != 0 && nr_process == 0)
   2296				|| (endptr == argv[1]) || (*endptr != '\0')) {
   2297			printk("Failed to parse [nr_process]");
   2298			exit_usage(argv);
   2299		}
   2300
   2301		if (nr_process > MAX_PROCESSES || !nr_process) {
   2302			printk("nr_process should be between [1; %u]",
   2303					MAX_PROCESSES);
   2304			exit_usage(argv);
   2305		}
   2306	}
   2307
   2308	srand(time(NULL));
   2309	page_size = sysconf(_SC_PAGESIZE);
   2310	if (page_size < 1)
   2311		ksft_exit_skip("sysconf(): %m\n");
   2312
   2313	if (pipe2(test_desc_fd, O_DIRECT) < 0)
   2314		ksft_exit_skip("pipe(): %m\n");
   2315
   2316	if (pipe2(results_fd, O_DIRECT) < 0)
   2317		ksft_exit_skip("pipe(): %m\n");
   2318
   2319	if (init_namespaces())
   2320		ksft_exit_skip("Failed to create namespaces\n");
   2321
   2322	if (netlink_sock(&route_sock, &route_seq, NETLINK_ROUTE))
   2323		ksft_exit_skip("Failed to open netlink route socket\n");
   2324
   2325	for (i = 0; i < nr_process; i++) {
   2326		char veth[VETH_LEN];
   2327
   2328		snprintf(veth, VETH_LEN, VETH_FMT, i);
   2329
   2330		if (veth_add(route_sock, route_seq++, veth, nsfd_childa, veth, nsfd_childb)) {
   2331			close(route_sock);
   2332			ksft_exit_fail_msg("Failed to create veth device");
   2333		}
   2334
   2335		if (start_child(i, veth, test_desc_fd)) {
   2336			close(route_sock);
   2337			ksft_exit_fail_msg("Child %u failed to start", i);
   2338		}
   2339	}
   2340
   2341	if (close(route_sock) || close(test_desc_fd[0]) || close(results_fd[1]))
   2342		ksft_exit_fail_msg("close(): %m");
   2343
   2344	ksft_set_plan(proto_plan + compat_plan);
   2345
   2346	if (write_test_plan(test_desc_fd[1]))
   2347		ksft_exit_fail_msg("Failed to write test plan to pipe");
   2348
   2349	ret = check_results();
   2350
   2351	if (children_cleanup() == KSFT_FAIL)
   2352		exit(KSFT_FAIL);
   2353
   2354	exit(ret);
   2355}