cachepc-linux

Fork of AMDESE/linux with modifications for the CachePC side-channel attack.
git clone https://git.sinitax.com/sinitax/cachepc-linux
Log | Files | Refs | README | LICENSE | sfeed.txt

netlink.c (17988B)


      1// SPDX-License-Identifier: GPL-2.0
      2/*
      3 * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
      4 */
      5
      6#include "netlink.h"
      7#include "device.h"
      8#include "peer.h"
      9#include "socket.h"
     10#include "queueing.h"
     11#include "messages.h"
     12
     13#include <uapi/linux/wireguard.h>
     14
     15#include <linux/if.h>
     16#include <net/genetlink.h>
     17#include <net/sock.h>
     18#include <crypto/algapi.h>
     19
     20static struct genl_family genl_family;
     21
     22static const struct nla_policy device_policy[WGDEVICE_A_MAX + 1] = {
     23	[WGDEVICE_A_IFINDEX]		= { .type = NLA_U32 },
     24	[WGDEVICE_A_IFNAME]		= { .type = NLA_NUL_STRING, .len = IFNAMSIZ - 1 },
     25	[WGDEVICE_A_PRIVATE_KEY]	= NLA_POLICY_EXACT_LEN(NOISE_PUBLIC_KEY_LEN),
     26	[WGDEVICE_A_PUBLIC_KEY]		= NLA_POLICY_EXACT_LEN(NOISE_PUBLIC_KEY_LEN),
     27	[WGDEVICE_A_FLAGS]		= { .type = NLA_U32 },
     28	[WGDEVICE_A_LISTEN_PORT]	= { .type = NLA_U16 },
     29	[WGDEVICE_A_FWMARK]		= { .type = NLA_U32 },
     30	[WGDEVICE_A_PEERS]		= { .type = NLA_NESTED }
     31};
     32
     33static const struct nla_policy peer_policy[WGPEER_A_MAX + 1] = {
     34	[WGPEER_A_PUBLIC_KEY]				= NLA_POLICY_EXACT_LEN(NOISE_PUBLIC_KEY_LEN),
     35	[WGPEER_A_PRESHARED_KEY]			= NLA_POLICY_EXACT_LEN(NOISE_SYMMETRIC_KEY_LEN),
     36	[WGPEER_A_FLAGS]				= { .type = NLA_U32 },
     37	[WGPEER_A_ENDPOINT]				= NLA_POLICY_MIN_LEN(sizeof(struct sockaddr)),
     38	[WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL]	= { .type = NLA_U16 },
     39	[WGPEER_A_LAST_HANDSHAKE_TIME]			= NLA_POLICY_EXACT_LEN(sizeof(struct __kernel_timespec)),
     40	[WGPEER_A_RX_BYTES]				= { .type = NLA_U64 },
     41	[WGPEER_A_TX_BYTES]				= { .type = NLA_U64 },
     42	[WGPEER_A_ALLOWEDIPS]				= { .type = NLA_NESTED },
     43	[WGPEER_A_PROTOCOL_VERSION]			= { .type = NLA_U32 }
     44};
     45
     46static const struct nla_policy allowedip_policy[WGALLOWEDIP_A_MAX + 1] = {
     47	[WGALLOWEDIP_A_FAMILY]		= { .type = NLA_U16 },
     48	[WGALLOWEDIP_A_IPADDR]		= NLA_POLICY_MIN_LEN(sizeof(struct in_addr)),
     49	[WGALLOWEDIP_A_CIDR_MASK]	= { .type = NLA_U8 }
     50};
     51
     52static struct wg_device *lookup_interface(struct nlattr **attrs,
     53					  struct sk_buff *skb)
     54{
     55	struct net_device *dev = NULL;
     56
     57	if (!attrs[WGDEVICE_A_IFINDEX] == !attrs[WGDEVICE_A_IFNAME])
     58		return ERR_PTR(-EBADR);
     59	if (attrs[WGDEVICE_A_IFINDEX])
     60		dev = dev_get_by_index(sock_net(skb->sk),
     61				       nla_get_u32(attrs[WGDEVICE_A_IFINDEX]));
     62	else if (attrs[WGDEVICE_A_IFNAME])
     63		dev = dev_get_by_name(sock_net(skb->sk),
     64				      nla_data(attrs[WGDEVICE_A_IFNAME]));
     65	if (!dev)
     66		return ERR_PTR(-ENODEV);
     67	if (!dev->rtnl_link_ops || !dev->rtnl_link_ops->kind ||
     68	    strcmp(dev->rtnl_link_ops->kind, KBUILD_MODNAME)) {
     69		dev_put(dev);
     70		return ERR_PTR(-EOPNOTSUPP);
     71	}
     72	return netdev_priv(dev);
     73}
     74
     75static int get_allowedips(struct sk_buff *skb, const u8 *ip, u8 cidr,
     76			  int family)
     77{
     78	struct nlattr *allowedip_nest;
     79
     80	allowedip_nest = nla_nest_start(skb, 0);
     81	if (!allowedip_nest)
     82		return -EMSGSIZE;
     83
     84	if (nla_put_u8(skb, WGALLOWEDIP_A_CIDR_MASK, cidr) ||
     85	    nla_put_u16(skb, WGALLOWEDIP_A_FAMILY, family) ||
     86	    nla_put(skb, WGALLOWEDIP_A_IPADDR, family == AF_INET6 ?
     87		    sizeof(struct in6_addr) : sizeof(struct in_addr), ip)) {
     88		nla_nest_cancel(skb, allowedip_nest);
     89		return -EMSGSIZE;
     90	}
     91
     92	nla_nest_end(skb, allowedip_nest);
     93	return 0;
     94}
     95
     96struct dump_ctx {
     97	struct wg_device *wg;
     98	struct wg_peer *next_peer;
     99	u64 allowedips_seq;
    100	struct allowedips_node *next_allowedip;
    101};
    102
    103#define DUMP_CTX(cb) ((struct dump_ctx *)(cb)->args)
    104
    105static int
    106get_peer(struct wg_peer *peer, struct sk_buff *skb, struct dump_ctx *ctx)
    107{
    108
    109	struct nlattr *allowedips_nest, *peer_nest = nla_nest_start(skb, 0);
    110	struct allowedips_node *allowedips_node = ctx->next_allowedip;
    111	bool fail;
    112
    113	if (!peer_nest)
    114		return -EMSGSIZE;
    115
    116	down_read(&peer->handshake.lock);
    117	fail = nla_put(skb, WGPEER_A_PUBLIC_KEY, NOISE_PUBLIC_KEY_LEN,
    118		       peer->handshake.remote_static);
    119	up_read(&peer->handshake.lock);
    120	if (fail)
    121		goto err;
    122
    123	if (!allowedips_node) {
    124		const struct __kernel_timespec last_handshake = {
    125			.tv_sec = peer->walltime_last_handshake.tv_sec,
    126			.tv_nsec = peer->walltime_last_handshake.tv_nsec
    127		};
    128
    129		down_read(&peer->handshake.lock);
    130		fail = nla_put(skb, WGPEER_A_PRESHARED_KEY,
    131			       NOISE_SYMMETRIC_KEY_LEN,
    132			       peer->handshake.preshared_key);
    133		up_read(&peer->handshake.lock);
    134		if (fail)
    135			goto err;
    136
    137		if (nla_put(skb, WGPEER_A_LAST_HANDSHAKE_TIME,
    138			    sizeof(last_handshake), &last_handshake) ||
    139		    nla_put_u16(skb, WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL,
    140				peer->persistent_keepalive_interval) ||
    141		    nla_put_u64_64bit(skb, WGPEER_A_TX_BYTES, peer->tx_bytes,
    142				      WGPEER_A_UNSPEC) ||
    143		    nla_put_u64_64bit(skb, WGPEER_A_RX_BYTES, peer->rx_bytes,
    144				      WGPEER_A_UNSPEC) ||
    145		    nla_put_u32(skb, WGPEER_A_PROTOCOL_VERSION, 1))
    146			goto err;
    147
    148		read_lock_bh(&peer->endpoint_lock);
    149		if (peer->endpoint.addr.sa_family == AF_INET)
    150			fail = nla_put(skb, WGPEER_A_ENDPOINT,
    151				       sizeof(peer->endpoint.addr4),
    152				       &peer->endpoint.addr4);
    153		else if (peer->endpoint.addr.sa_family == AF_INET6)
    154			fail = nla_put(skb, WGPEER_A_ENDPOINT,
    155				       sizeof(peer->endpoint.addr6),
    156				       &peer->endpoint.addr6);
    157		read_unlock_bh(&peer->endpoint_lock);
    158		if (fail)
    159			goto err;
    160		allowedips_node =
    161			list_first_entry_or_null(&peer->allowedips_list,
    162					struct allowedips_node, peer_list);
    163	}
    164	if (!allowedips_node)
    165		goto no_allowedips;
    166	if (!ctx->allowedips_seq)
    167		ctx->allowedips_seq = peer->device->peer_allowedips.seq;
    168	else if (ctx->allowedips_seq != peer->device->peer_allowedips.seq)
    169		goto no_allowedips;
    170
    171	allowedips_nest = nla_nest_start(skb, WGPEER_A_ALLOWEDIPS);
    172	if (!allowedips_nest)
    173		goto err;
    174
    175	list_for_each_entry_from(allowedips_node, &peer->allowedips_list,
    176				 peer_list) {
    177		u8 cidr, ip[16] __aligned(__alignof(u64));
    178		int family;
    179
    180		family = wg_allowedips_read_node(allowedips_node, ip, &cidr);
    181		if (get_allowedips(skb, ip, cidr, family)) {
    182			nla_nest_end(skb, allowedips_nest);
    183			nla_nest_end(skb, peer_nest);
    184			ctx->next_allowedip = allowedips_node;
    185			return -EMSGSIZE;
    186		}
    187	}
    188	nla_nest_end(skb, allowedips_nest);
    189no_allowedips:
    190	nla_nest_end(skb, peer_nest);
    191	ctx->next_allowedip = NULL;
    192	ctx->allowedips_seq = 0;
    193	return 0;
    194err:
    195	nla_nest_cancel(skb, peer_nest);
    196	return -EMSGSIZE;
    197}
    198
    199static int wg_get_device_start(struct netlink_callback *cb)
    200{
    201	struct wg_device *wg;
    202
    203	wg = lookup_interface(genl_dumpit_info(cb)->attrs, cb->skb);
    204	if (IS_ERR(wg))
    205		return PTR_ERR(wg);
    206	DUMP_CTX(cb)->wg = wg;
    207	return 0;
    208}
    209
    210static int wg_get_device_dump(struct sk_buff *skb, struct netlink_callback *cb)
    211{
    212	struct wg_peer *peer, *next_peer_cursor;
    213	struct dump_ctx *ctx = DUMP_CTX(cb);
    214	struct wg_device *wg = ctx->wg;
    215	struct nlattr *peers_nest;
    216	int ret = -EMSGSIZE;
    217	bool done = true;
    218	void *hdr;
    219
    220	rtnl_lock();
    221	mutex_lock(&wg->device_update_lock);
    222	cb->seq = wg->device_update_gen;
    223	next_peer_cursor = ctx->next_peer;
    224
    225	hdr = genlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq,
    226			  &genl_family, NLM_F_MULTI, WG_CMD_GET_DEVICE);
    227	if (!hdr)
    228		goto out;
    229	genl_dump_check_consistent(cb, hdr);
    230
    231	if (!ctx->next_peer) {
    232		if (nla_put_u16(skb, WGDEVICE_A_LISTEN_PORT,
    233				wg->incoming_port) ||
    234		    nla_put_u32(skb, WGDEVICE_A_FWMARK, wg->fwmark) ||
    235		    nla_put_u32(skb, WGDEVICE_A_IFINDEX, wg->dev->ifindex) ||
    236		    nla_put_string(skb, WGDEVICE_A_IFNAME, wg->dev->name))
    237			goto out;
    238
    239		down_read(&wg->static_identity.lock);
    240		if (wg->static_identity.has_identity) {
    241			if (nla_put(skb, WGDEVICE_A_PRIVATE_KEY,
    242				    NOISE_PUBLIC_KEY_LEN,
    243				    wg->static_identity.static_private) ||
    244			    nla_put(skb, WGDEVICE_A_PUBLIC_KEY,
    245				    NOISE_PUBLIC_KEY_LEN,
    246				    wg->static_identity.static_public)) {
    247				up_read(&wg->static_identity.lock);
    248				goto out;
    249			}
    250		}
    251		up_read(&wg->static_identity.lock);
    252	}
    253
    254	peers_nest = nla_nest_start(skb, WGDEVICE_A_PEERS);
    255	if (!peers_nest)
    256		goto out;
    257	ret = 0;
    258	/* If the last cursor was removed via list_del_init in peer_remove, then
    259	 * we just treat this the same as there being no more peers left. The
    260	 * reason is that seq_nr should indicate to userspace that this isn't a
    261	 * coherent dump anyway, so they'll try again.
    262	 */
    263	if (list_empty(&wg->peer_list) ||
    264	    (ctx->next_peer && list_empty(&ctx->next_peer->peer_list))) {
    265		nla_nest_cancel(skb, peers_nest);
    266		goto out;
    267	}
    268	lockdep_assert_held(&wg->device_update_lock);
    269	peer = list_prepare_entry(ctx->next_peer, &wg->peer_list, peer_list);
    270	list_for_each_entry_continue(peer, &wg->peer_list, peer_list) {
    271		if (get_peer(peer, skb, ctx)) {
    272			done = false;
    273			break;
    274		}
    275		next_peer_cursor = peer;
    276	}
    277	nla_nest_end(skb, peers_nest);
    278
    279out:
    280	if (!ret && !done && next_peer_cursor)
    281		wg_peer_get(next_peer_cursor);
    282	wg_peer_put(ctx->next_peer);
    283	mutex_unlock(&wg->device_update_lock);
    284	rtnl_unlock();
    285
    286	if (ret) {
    287		genlmsg_cancel(skb, hdr);
    288		return ret;
    289	}
    290	genlmsg_end(skb, hdr);
    291	if (done) {
    292		ctx->next_peer = NULL;
    293		return 0;
    294	}
    295	ctx->next_peer = next_peer_cursor;
    296	return skb->len;
    297
    298	/* At this point, we can't really deal ourselves with safely zeroing out
    299	 * the private key material after usage. This will need an additional API
    300	 * in the kernel for marking skbs as zero_on_free.
    301	 */
    302}
    303
    304static int wg_get_device_done(struct netlink_callback *cb)
    305{
    306	struct dump_ctx *ctx = DUMP_CTX(cb);
    307
    308	if (ctx->wg)
    309		dev_put(ctx->wg->dev);
    310	wg_peer_put(ctx->next_peer);
    311	return 0;
    312}
    313
    314static int set_port(struct wg_device *wg, u16 port)
    315{
    316	struct wg_peer *peer;
    317
    318	if (wg->incoming_port == port)
    319		return 0;
    320	list_for_each_entry(peer, &wg->peer_list, peer_list)
    321		wg_socket_clear_peer_endpoint_src(peer);
    322	if (!netif_running(wg->dev)) {
    323		wg->incoming_port = port;
    324		return 0;
    325	}
    326	return wg_socket_init(wg, port);
    327}
    328
    329static int set_allowedip(struct wg_peer *peer, struct nlattr **attrs)
    330{
    331	int ret = -EINVAL;
    332	u16 family;
    333	u8 cidr;
    334
    335	if (!attrs[WGALLOWEDIP_A_FAMILY] || !attrs[WGALLOWEDIP_A_IPADDR] ||
    336	    !attrs[WGALLOWEDIP_A_CIDR_MASK])
    337		return ret;
    338	family = nla_get_u16(attrs[WGALLOWEDIP_A_FAMILY]);
    339	cidr = nla_get_u8(attrs[WGALLOWEDIP_A_CIDR_MASK]);
    340
    341	if (family == AF_INET && cidr <= 32 &&
    342	    nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in_addr))
    343		ret = wg_allowedips_insert_v4(
    344			&peer->device->peer_allowedips,
    345			nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, peer,
    346			&peer->device->device_update_lock);
    347	else if (family == AF_INET6 && cidr <= 128 &&
    348		 nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in6_addr))
    349		ret = wg_allowedips_insert_v6(
    350			&peer->device->peer_allowedips,
    351			nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, peer,
    352			&peer->device->device_update_lock);
    353
    354	return ret;
    355}
    356
    357static int set_peer(struct wg_device *wg, struct nlattr **attrs)
    358{
    359	u8 *public_key = NULL, *preshared_key = NULL;
    360	struct wg_peer *peer = NULL;
    361	u32 flags = 0;
    362	int ret;
    363
    364	ret = -EINVAL;
    365	if (attrs[WGPEER_A_PUBLIC_KEY] &&
    366	    nla_len(attrs[WGPEER_A_PUBLIC_KEY]) == NOISE_PUBLIC_KEY_LEN)
    367		public_key = nla_data(attrs[WGPEER_A_PUBLIC_KEY]);
    368	else
    369		goto out;
    370	if (attrs[WGPEER_A_PRESHARED_KEY] &&
    371	    nla_len(attrs[WGPEER_A_PRESHARED_KEY]) == NOISE_SYMMETRIC_KEY_LEN)
    372		preshared_key = nla_data(attrs[WGPEER_A_PRESHARED_KEY]);
    373
    374	if (attrs[WGPEER_A_FLAGS])
    375		flags = nla_get_u32(attrs[WGPEER_A_FLAGS]);
    376	ret = -EOPNOTSUPP;
    377	if (flags & ~__WGPEER_F_ALL)
    378		goto out;
    379
    380	ret = -EPFNOSUPPORT;
    381	if (attrs[WGPEER_A_PROTOCOL_VERSION]) {
    382		if (nla_get_u32(attrs[WGPEER_A_PROTOCOL_VERSION]) != 1)
    383			goto out;
    384	}
    385
    386	peer = wg_pubkey_hashtable_lookup(wg->peer_hashtable,
    387					  nla_data(attrs[WGPEER_A_PUBLIC_KEY]));
    388	ret = 0;
    389	if (!peer) { /* Peer doesn't exist yet. Add a new one. */
    390		if (flags & (WGPEER_F_REMOVE_ME | WGPEER_F_UPDATE_ONLY))
    391			goto out;
    392
    393		/* The peer is new, so there aren't allowed IPs to remove. */
    394		flags &= ~WGPEER_F_REPLACE_ALLOWEDIPS;
    395
    396		down_read(&wg->static_identity.lock);
    397		if (wg->static_identity.has_identity &&
    398		    !memcmp(nla_data(attrs[WGPEER_A_PUBLIC_KEY]),
    399			    wg->static_identity.static_public,
    400			    NOISE_PUBLIC_KEY_LEN)) {
    401			/* We silently ignore peers that have the same public
    402			 * key as the device. The reason we do it silently is
    403			 * that we'd like for people to be able to reuse the
    404			 * same set of API calls across peers.
    405			 */
    406			up_read(&wg->static_identity.lock);
    407			ret = 0;
    408			goto out;
    409		}
    410		up_read(&wg->static_identity.lock);
    411
    412		peer = wg_peer_create(wg, public_key, preshared_key);
    413		if (IS_ERR(peer)) {
    414			ret = PTR_ERR(peer);
    415			peer = NULL;
    416			goto out;
    417		}
    418		/* Take additional reference, as though we've just been
    419		 * looked up.
    420		 */
    421		wg_peer_get(peer);
    422	}
    423
    424	if (flags & WGPEER_F_REMOVE_ME) {
    425		wg_peer_remove(peer);
    426		goto out;
    427	}
    428
    429	if (preshared_key) {
    430		down_write(&peer->handshake.lock);
    431		memcpy(&peer->handshake.preshared_key, preshared_key,
    432		       NOISE_SYMMETRIC_KEY_LEN);
    433		up_write(&peer->handshake.lock);
    434	}
    435
    436	if (attrs[WGPEER_A_ENDPOINT]) {
    437		struct sockaddr *addr = nla_data(attrs[WGPEER_A_ENDPOINT]);
    438		size_t len = nla_len(attrs[WGPEER_A_ENDPOINT]);
    439
    440		if ((len == sizeof(struct sockaddr_in) &&
    441		     addr->sa_family == AF_INET) ||
    442		    (len == sizeof(struct sockaddr_in6) &&
    443		     addr->sa_family == AF_INET6)) {
    444			struct endpoint endpoint = { { { 0 } } };
    445
    446			memcpy(&endpoint.addr, addr, len);
    447			wg_socket_set_peer_endpoint(peer, &endpoint);
    448		}
    449	}
    450
    451	if (flags & WGPEER_F_REPLACE_ALLOWEDIPS)
    452		wg_allowedips_remove_by_peer(&wg->peer_allowedips, peer,
    453					     &wg->device_update_lock);
    454
    455	if (attrs[WGPEER_A_ALLOWEDIPS]) {
    456		struct nlattr *attr, *allowedip[WGALLOWEDIP_A_MAX + 1];
    457		int rem;
    458
    459		nla_for_each_nested(attr, attrs[WGPEER_A_ALLOWEDIPS], rem) {
    460			ret = nla_parse_nested(allowedip, WGALLOWEDIP_A_MAX,
    461					       attr, allowedip_policy, NULL);
    462			if (ret < 0)
    463				goto out;
    464			ret = set_allowedip(peer, allowedip);
    465			if (ret < 0)
    466				goto out;
    467		}
    468	}
    469
    470	if (attrs[WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL]) {
    471		const u16 persistent_keepalive_interval = nla_get_u16(
    472				attrs[WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL]);
    473		const bool send_keepalive =
    474			!peer->persistent_keepalive_interval &&
    475			persistent_keepalive_interval &&
    476			netif_running(wg->dev);
    477
    478		peer->persistent_keepalive_interval = persistent_keepalive_interval;
    479		if (send_keepalive)
    480			wg_packet_send_keepalive(peer);
    481	}
    482
    483	if (netif_running(wg->dev))
    484		wg_packet_send_staged_packets(peer);
    485
    486out:
    487	wg_peer_put(peer);
    488	if (attrs[WGPEER_A_PRESHARED_KEY])
    489		memzero_explicit(nla_data(attrs[WGPEER_A_PRESHARED_KEY]),
    490				 nla_len(attrs[WGPEER_A_PRESHARED_KEY]));
    491	return ret;
    492}
    493
    494static int wg_set_device(struct sk_buff *skb, struct genl_info *info)
    495{
    496	struct wg_device *wg = lookup_interface(info->attrs, skb);
    497	u32 flags = 0;
    498	int ret;
    499
    500	if (IS_ERR(wg)) {
    501		ret = PTR_ERR(wg);
    502		goto out_nodev;
    503	}
    504
    505	rtnl_lock();
    506	mutex_lock(&wg->device_update_lock);
    507
    508	if (info->attrs[WGDEVICE_A_FLAGS])
    509		flags = nla_get_u32(info->attrs[WGDEVICE_A_FLAGS]);
    510	ret = -EOPNOTSUPP;
    511	if (flags & ~__WGDEVICE_F_ALL)
    512		goto out;
    513
    514	if (info->attrs[WGDEVICE_A_LISTEN_PORT] || info->attrs[WGDEVICE_A_FWMARK]) {
    515		struct net *net;
    516		rcu_read_lock();
    517		net = rcu_dereference(wg->creating_net);
    518		ret = !net || !ns_capable(net->user_ns, CAP_NET_ADMIN) ? -EPERM : 0;
    519		rcu_read_unlock();
    520		if (ret)
    521			goto out;
    522	}
    523
    524	++wg->device_update_gen;
    525
    526	if (info->attrs[WGDEVICE_A_FWMARK]) {
    527		struct wg_peer *peer;
    528
    529		wg->fwmark = nla_get_u32(info->attrs[WGDEVICE_A_FWMARK]);
    530		list_for_each_entry(peer, &wg->peer_list, peer_list)
    531			wg_socket_clear_peer_endpoint_src(peer);
    532	}
    533
    534	if (info->attrs[WGDEVICE_A_LISTEN_PORT]) {
    535		ret = set_port(wg,
    536			nla_get_u16(info->attrs[WGDEVICE_A_LISTEN_PORT]));
    537		if (ret)
    538			goto out;
    539	}
    540
    541	if (flags & WGDEVICE_F_REPLACE_PEERS)
    542		wg_peer_remove_all(wg);
    543
    544	if (info->attrs[WGDEVICE_A_PRIVATE_KEY] &&
    545	    nla_len(info->attrs[WGDEVICE_A_PRIVATE_KEY]) ==
    546		    NOISE_PUBLIC_KEY_LEN) {
    547		u8 *private_key = nla_data(info->attrs[WGDEVICE_A_PRIVATE_KEY]);
    548		u8 public_key[NOISE_PUBLIC_KEY_LEN];
    549		struct wg_peer *peer, *temp;
    550
    551		if (!crypto_memneq(wg->static_identity.static_private,
    552				   private_key, NOISE_PUBLIC_KEY_LEN))
    553			goto skip_set_private_key;
    554
    555		/* We remove before setting, to prevent race, which means doing
    556		 * two 25519-genpub ops.
    557		 */
    558		if (curve25519_generate_public(public_key, private_key)) {
    559			peer = wg_pubkey_hashtable_lookup(wg->peer_hashtable,
    560							  public_key);
    561			if (peer) {
    562				wg_peer_put(peer);
    563				wg_peer_remove(peer);
    564			}
    565		}
    566
    567		down_write(&wg->static_identity.lock);
    568		wg_noise_set_static_identity_private_key(&wg->static_identity,
    569							 private_key);
    570		list_for_each_entry_safe(peer, temp, &wg->peer_list,
    571					 peer_list) {
    572			wg_noise_precompute_static_static(peer);
    573			wg_noise_expire_current_peer_keypairs(peer);
    574		}
    575		wg_cookie_checker_precompute_device_keys(&wg->cookie_checker);
    576		up_write(&wg->static_identity.lock);
    577	}
    578skip_set_private_key:
    579
    580	if (info->attrs[WGDEVICE_A_PEERS]) {
    581		struct nlattr *attr, *peer[WGPEER_A_MAX + 1];
    582		int rem;
    583
    584		nla_for_each_nested(attr, info->attrs[WGDEVICE_A_PEERS], rem) {
    585			ret = nla_parse_nested(peer, WGPEER_A_MAX, attr,
    586					       peer_policy, NULL);
    587			if (ret < 0)
    588				goto out;
    589			ret = set_peer(wg, peer);
    590			if (ret < 0)
    591				goto out;
    592		}
    593	}
    594	ret = 0;
    595
    596out:
    597	mutex_unlock(&wg->device_update_lock);
    598	rtnl_unlock();
    599	dev_put(wg->dev);
    600out_nodev:
    601	if (info->attrs[WGDEVICE_A_PRIVATE_KEY])
    602		memzero_explicit(nla_data(info->attrs[WGDEVICE_A_PRIVATE_KEY]),
    603				 nla_len(info->attrs[WGDEVICE_A_PRIVATE_KEY]));
    604	return ret;
    605}
    606
    607static const struct genl_ops genl_ops[] = {
    608	{
    609		.cmd = WG_CMD_GET_DEVICE,
    610		.start = wg_get_device_start,
    611		.dumpit = wg_get_device_dump,
    612		.done = wg_get_device_done,
    613		.flags = GENL_UNS_ADMIN_PERM
    614	}, {
    615		.cmd = WG_CMD_SET_DEVICE,
    616		.doit = wg_set_device,
    617		.flags = GENL_UNS_ADMIN_PERM
    618	}
    619};
    620
    621static struct genl_family genl_family __ro_after_init = {
    622	.ops = genl_ops,
    623	.n_ops = ARRAY_SIZE(genl_ops),
    624	.name = WG_GENL_NAME,
    625	.version = WG_GENL_VERSION,
    626	.maxattr = WGDEVICE_A_MAX,
    627	.module = THIS_MODULE,
    628	.policy = device_policy,
    629	.netnsok = true
    630};
    631
    632int __init wg_genetlink_init(void)
    633{
    634	return genl_register_family(&genl_family);
    635}
    636
    637void __exit wg_genetlink_uninit(void)
    638{
    639	genl_unregister_family(&genl_family);
    640}