cachepc-linux

Fork of AMDESE/linux with modifications for CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux
Log | Files | Refs | README | LICENSE | sfeed.txt

noise.c (27778B)


      1// SPDX-License-Identifier: GPL-2.0
      2/*
      3 * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
      4 */
      5
      6#include "noise.h"
      7#include "device.h"
      8#include "peer.h"
      9#include "messages.h"
     10#include "queueing.h"
     11#include "peerlookup.h"
     12
     13#include <linux/rcupdate.h>
     14#include <linux/slab.h>
     15#include <linux/bitmap.h>
     16#include <linux/scatterlist.h>
     17#include <linux/highmem.h>
     18#include <crypto/algapi.h>
     19
/* This implements Noise_IKpsk2:
 *
 * <- s
 * ******
 * -> e, es, s, ss, {t}
 * <- e, ee, se, psk, {}
 *
 * The "<- s" line before the separator is the pre-message: the initiator
 * already knows the responder's static public key, which handshake_init()
 * mixes into the hash before the first message.
 */

/* Protocol name and identifier. The array sizes are spelled out and equal
 * the string lengths, so no terminating NUL is stored and sizeof() covers
 * exactly the bytes that wg_noise_init() hashes.
 */
static const u8 handshake_name[37] = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s";
static const u8 identifier_name[34] = "WireGuard v1 zx2c4 Jason@zx2c4.com";
/* Initial hash / chaining key, computed once in wg_noise_init() and copied
 * into each new handshake by handshake_init().
 */
static u8 handshake_init_hash[NOISE_HASH_LEN] __ro_after_init;
static u8 handshake_init_chaining_key[NOISE_HASH_LEN] __ro_after_init;
/* Source of keypair->internal_id values handed out by keypair_create(). */
static atomic64_t keypair_counter = ATOMIC64_INIT(0);
     33
     34void __init wg_noise_init(void)
     35{
     36	struct blake2s_state blake;
     37
     38	blake2s(handshake_init_chaining_key, handshake_name, NULL,
     39		NOISE_HASH_LEN, sizeof(handshake_name), 0);
     40	blake2s_init(&blake, NOISE_HASH_LEN);
     41	blake2s_update(&blake, handshake_init_chaining_key, NOISE_HASH_LEN);
     42	blake2s_update(&blake, identifier_name, sizeof(identifier_name));
     43	blake2s_final(&blake, handshake_init_hash);
     44}
     45
     46/* Must hold peer->handshake.static_identity->lock */
     47void wg_noise_precompute_static_static(struct wg_peer *peer)
     48{
     49	down_write(&peer->handshake.lock);
     50	if (!peer->handshake.static_identity->has_identity ||
     51	    !curve25519(peer->handshake.precomputed_static_static,
     52			peer->handshake.static_identity->static_private,
     53			peer->handshake.remote_static))
     54		memset(peer->handshake.precomputed_static_static, 0,
     55		       NOISE_PUBLIC_KEY_LEN);
     56	up_write(&peer->handshake.lock);
     57}
     58
     59void wg_noise_handshake_init(struct noise_handshake *handshake,
     60			     struct noise_static_identity *static_identity,
     61			     const u8 peer_public_key[NOISE_PUBLIC_KEY_LEN],
     62			     const u8 peer_preshared_key[NOISE_SYMMETRIC_KEY_LEN],
     63			     struct wg_peer *peer)
     64{
     65	memset(handshake, 0, sizeof(*handshake));
     66	init_rwsem(&handshake->lock);
     67	handshake->entry.type = INDEX_HASHTABLE_HANDSHAKE;
     68	handshake->entry.peer = peer;
     69	memcpy(handshake->remote_static, peer_public_key, NOISE_PUBLIC_KEY_LEN);
     70	if (peer_preshared_key)
     71		memcpy(handshake->preshared_key, peer_preshared_key,
     72		       NOISE_SYMMETRIC_KEY_LEN);
     73	handshake->static_identity = static_identity;
     74	handshake->state = HANDSHAKE_ZEROED;
     75	wg_noise_precompute_static_static(peer);
     76}
     77
     78static void handshake_zero(struct noise_handshake *handshake)
     79{
     80	memset(&handshake->ephemeral_private, 0, NOISE_PUBLIC_KEY_LEN);
     81	memset(&handshake->remote_ephemeral, 0, NOISE_PUBLIC_KEY_LEN);
     82	memset(&handshake->hash, 0, NOISE_HASH_LEN);
     83	memset(&handshake->chaining_key, 0, NOISE_HASH_LEN);
     84	handshake->remote_index = 0;
     85	handshake->state = HANDSHAKE_ZEROED;
     86}
     87
     88void wg_noise_handshake_clear(struct noise_handshake *handshake)
     89{
     90	down_write(&handshake->lock);
     91	wg_index_hashtable_remove(
     92			handshake->entry.peer->device->index_hashtable,
     93			&handshake->entry);
     94	handshake_zero(handshake);
     95	up_write(&handshake->lock);
     96}
     97
     98static struct noise_keypair *keypair_create(struct wg_peer *peer)
     99{
    100	struct noise_keypair *keypair = kzalloc(sizeof(*keypair), GFP_KERNEL);
    101
    102	if (unlikely(!keypair))
    103		return NULL;
    104	spin_lock_init(&keypair->receiving_counter.lock);
    105	keypair->internal_id = atomic64_inc_return(&keypair_counter);
    106	keypair->entry.type = INDEX_HASHTABLE_KEYPAIR;
    107	keypair->entry.peer = peer;
    108	kref_init(&keypair->refcount);
    109	return keypair;
    110}
    111
    112static void keypair_free_rcu(struct rcu_head *rcu)
    113{
    114	kfree_sensitive(container_of(rcu, struct noise_keypair, rcu));
    115}
    116
    117static void keypair_free_kref(struct kref *kref)
    118{
    119	struct noise_keypair *keypair =
    120		container_of(kref, struct noise_keypair, refcount);
    121
    122	net_dbg_ratelimited("%s: Keypair %llu destroyed for peer %llu\n",
    123			    keypair->entry.peer->device->dev->name,
    124			    keypair->internal_id,
    125			    keypair->entry.peer->internal_id);
    126	wg_index_hashtable_remove(keypair->entry.peer->device->index_hashtable,
    127				  &keypair->entry);
    128	call_rcu(&keypair->rcu, keypair_free_rcu);
    129}
    130
    131void wg_noise_keypair_put(struct noise_keypair *keypair, bool unreference_now)
    132{
    133	if (unlikely(!keypair))
    134		return;
    135	if (unlikely(unreference_now))
    136		wg_index_hashtable_remove(
    137			keypair->entry.peer->device->index_hashtable,
    138			&keypair->entry);
    139	kref_put(&keypair->refcount, keypair_free_kref);
    140}
    141
    142struct noise_keypair *wg_noise_keypair_get(struct noise_keypair *keypair)
    143{
    144	RCU_LOCKDEP_WARN(!rcu_read_lock_bh_held(),
    145		"Taking noise keypair reference without holding the RCU BH read lock");
    146	if (unlikely(!keypair || !kref_get_unless_zero(&keypair->refcount)))
    147		return NULL;
    148	return keypair;
    149}
    150
/*
 * Empty all three keypair slots of @keypairs (next, previous, current) under
 * keypair_update_lock. Each old keypair is released with unreference_now set,
 * so it also leaves the index hashtable immediately.
 */
void wg_noise_keypairs_clear(struct noise_keypairs *keypairs)
{
	struct noise_keypair *old;

	spin_lock_bh(&keypairs->keypair_update_lock);

	/* We zero the next_keypair before zeroing the others, so that
	 * wg_noise_received_with_keypair returns early before subsequent ones
	 * are zeroed.
	 */
	/* Storing NULL publishes nothing, so RCU_INIT_POINTER suffices. */
	old = rcu_dereference_protected(keypairs->next_keypair,
		lockdep_is_held(&keypairs->keypair_update_lock));
	RCU_INIT_POINTER(keypairs->next_keypair, NULL);
	wg_noise_keypair_put(old, true);

	old = rcu_dereference_protected(keypairs->previous_keypair,
		lockdep_is_held(&keypairs->keypair_update_lock));
	RCU_INIT_POINTER(keypairs->previous_keypair, NULL);
	wg_noise_keypair_put(old, true);

	old = rcu_dereference_protected(keypairs->current_keypair,
		lockdep_is_held(&keypairs->keypair_update_lock));
	RCU_INIT_POINTER(keypairs->current_keypair, NULL);
	wg_noise_keypair_put(old, true);

	spin_unlock_bh(&keypairs->keypair_update_lock);
}
    178
    179void wg_noise_expire_current_peer_keypairs(struct wg_peer *peer)
    180{
    181	struct noise_keypair *keypair;
    182
    183	wg_noise_handshake_clear(&peer->handshake);
    184	wg_noise_reset_last_sent_handshake(&peer->last_sent_handshake);
    185
    186	spin_lock_bh(&peer->keypairs.keypair_update_lock);
    187	keypair = rcu_dereference_protected(peer->keypairs.next_keypair,
    188			lockdep_is_held(&peer->keypairs.keypair_update_lock));
    189	if (keypair)
    190		keypair->sending.is_valid = false;
    191	keypair = rcu_dereference_protected(peer->keypairs.current_keypair,
    192			lockdep_is_held(&peer->keypairs.keypair_update_lock));
    193	if (keypair)
    194		keypair->sending.is_valid = false;
    195	spin_unlock_bh(&peer->keypairs.keypair_update_lock);
    196}
    197
/*
 * Install @new_keypair into the three RCU-protected slots of @keypairs,
 * under keypair_update_lock.
 *
 * Initiator (new_keypair->i_am_the_initiator): the new keypair becomes
 * current immediately. A pending next keypair, if any, moves to previous and
 * the old current is released; otherwise the old current moves to previous.
 * In both cases the old previous is released.
 *
 * Responder: the new keypair replaces next and waits there for confirmation
 * (see wg_noise_received_with_keypair()). The old next and old previous are
 * released, previous is left NULL, and current is unchanged.
 *
 * Released keypairs are dropped with unreference_now set, which also removes
 * them from the index hashtable immediately.
 */
static void add_new_keypair(struct noise_keypairs *keypairs,
			    struct noise_keypair *new_keypair)
{
	struct noise_keypair *previous_keypair, *next_keypair, *current_keypair;

	spin_lock_bh(&keypairs->keypair_update_lock);
	previous_keypair = rcu_dereference_protected(keypairs->previous_keypair,
		lockdep_is_held(&keypairs->keypair_update_lock));
	next_keypair = rcu_dereference_protected(keypairs->next_keypair,
		lockdep_is_held(&keypairs->keypair_update_lock));
	current_keypair = rcu_dereference_protected(keypairs->current_keypair,
		lockdep_is_held(&keypairs->keypair_update_lock));
	if (new_keypair->i_am_the_initiator) {
		/* If we're the initiator, it means we've sent a handshake, and
		 * received a confirmation response, which means this new
		 * keypair can now be used.
		 */
		if (next_keypair) {
			/* If there already was a next keypair pending, we
			 * demote it to be the previous keypair, and free the
			 * existing current. Note that this means KCI can result
			 * in this transition. It would perhaps be more sound to
			 * always just get rid of the unused next keypair
			 * instead of putting it in the previous slot, but this
			 * might be a bit less robust. Something to think about
			 * for the future.
			 */
			RCU_INIT_POINTER(keypairs->next_keypair, NULL);
			rcu_assign_pointer(keypairs->previous_keypair,
					   next_keypair);
			wg_noise_keypair_put(current_keypair, true);
		} else /* If there wasn't an existing next keypair, we replace
			* the previous with the current one.
			*/
			rcu_assign_pointer(keypairs->previous_keypair,
					   current_keypair);
		/* At this point we can get rid of the old previous keypair, and
		 * set up the new keypair.
		 */
		wg_noise_keypair_put(previous_keypair, true);
		/* rcu_assign_pointer orders the keypair's initialisation before
		 * its publication to lockless readers.
		 */
		rcu_assign_pointer(keypairs->current_keypair, new_keypair);
	} else {
		/* If we're the responder, it means we can't use the new keypair
		 * until we receive confirmation via the first data packet, so
		 * we get rid of the existing previous one, the possibly
		 * existing next one, and slide in the new next one.
		 */
		rcu_assign_pointer(keypairs->next_keypair, new_keypair);
		wg_noise_keypair_put(next_keypair, true);
		RCU_INIT_POINTER(keypairs->previous_keypair, NULL);
		wg_noise_keypair_put(previous_keypair, true);
	}
	spin_unlock_bh(&keypairs->keypair_update_lock);
}
    252
/*
 * Note that traffic was received on @received_keypair. If it is the pending
 * next keypair, promote it: next -> current, current -> previous, and release
 * the old previous (with unreference_now). Returns true only if this call
 * performed the promotion; false otherwise, including when another CPU got
 * there first.
 *
 * NOTE(review): presumably called from the data receive path; this is the
 * "confirmation via the first data packet" that add_new_keypair() waits for
 * on the responder side.
 */
bool wg_noise_received_with_keypair(struct noise_keypairs *keypairs,
				    struct noise_keypair *received_keypair)
{
	struct noise_keypair *old_keypair;
	bool key_is_new;

	/* We first check without taking the spinlock. */
	/* rcu_access_pointer only fetches the pointer value for comparison;
	 * it is never dereferenced here.
	 */
	key_is_new = received_keypair ==
		     rcu_access_pointer(keypairs->next_keypair);
	if (likely(!key_is_new))
		return false;

	spin_lock_bh(&keypairs->keypair_update_lock);
	/* After locking, we double check that things didn't change from
	 * beneath us.
	 */
	if (unlikely(received_keypair !=
		    rcu_dereference_protected(keypairs->next_keypair,
			    lockdep_is_held(&keypairs->keypair_update_lock)))) {
		spin_unlock_bh(&keypairs->keypair_update_lock);
		return false;
	}

	/* When we've finally received the confirmation, we slide the next
	 * into the current, the current into the previous, and get rid of
	 * the old previous.
	 */
	old_keypair = rcu_dereference_protected(keypairs->previous_keypair,
		lockdep_is_held(&keypairs->keypair_update_lock));
	rcu_assign_pointer(keypairs->previous_keypair,
		rcu_dereference_protected(keypairs->current_keypair,
			lockdep_is_held(&keypairs->keypair_update_lock)));
	wg_noise_keypair_put(old_keypair, true);
	rcu_assign_pointer(keypairs->current_keypair, received_keypair);
	RCU_INIT_POINTER(keypairs->next_keypair, NULL);

	spin_unlock_bh(&keypairs->keypair_update_lock);
	return true;
}
    292
    293/* Must hold static_identity->lock */
    294void wg_noise_set_static_identity_private_key(
    295	struct noise_static_identity *static_identity,
    296	const u8 private_key[NOISE_PUBLIC_KEY_LEN])
    297{
    298	memcpy(static_identity->static_private, private_key,
    299	       NOISE_PUBLIC_KEY_LEN);
    300	curve25519_clamp_secret(static_identity->static_private);
    301	static_identity->has_identity = curve25519_generate_public(
    302		static_identity->static_public, private_key);
    303}
    304
    305static void hmac(u8 *out, const u8 *in, const u8 *key, const size_t inlen, const size_t keylen)
    306{
    307	struct blake2s_state state;
    308	u8 x_key[BLAKE2S_BLOCK_SIZE] __aligned(__alignof__(u32)) = { 0 };
    309	u8 i_hash[BLAKE2S_HASH_SIZE] __aligned(__alignof__(u32));
    310	int i;
    311
    312	if (keylen > BLAKE2S_BLOCK_SIZE) {
    313		blake2s_init(&state, BLAKE2S_HASH_SIZE);
    314		blake2s_update(&state, key, keylen);
    315		blake2s_final(&state, x_key);
    316	} else
    317		memcpy(x_key, key, keylen);
    318
    319	for (i = 0; i < BLAKE2S_BLOCK_SIZE; ++i)
    320		x_key[i] ^= 0x36;
    321
    322	blake2s_init(&state, BLAKE2S_HASH_SIZE);
    323	blake2s_update(&state, x_key, BLAKE2S_BLOCK_SIZE);
    324	blake2s_update(&state, in, inlen);
    325	blake2s_final(&state, i_hash);
    326
    327	for (i = 0; i < BLAKE2S_BLOCK_SIZE; ++i)
    328		x_key[i] ^= 0x5c ^ 0x36;
    329
    330	blake2s_init(&state, BLAKE2S_HASH_SIZE);
    331	blake2s_update(&state, x_key, BLAKE2S_BLOCK_SIZE);
    332	blake2s_update(&state, i_hash, BLAKE2S_HASH_SIZE);
    333	blake2s_final(&state, i_hash);
    334
    335	memcpy(out, i_hash, BLAKE2S_HASH_SIZE);
    336	memzero_explicit(x_key, BLAKE2S_BLOCK_SIZE);
    337	memzero_explicit(i_hash, BLAKE2S_HASH_SIZE);
    338}
    339
/* This is Hugo Krawczyk's HKDF:
 *  - https://eprint.iacr.org/2010/264.pdf
 *  - https://tools.ietf.org/html/rfc5869
 *
 * Extract: secret = HMAC(chaining_key, data).
 * Expand:  T(1) = HMAC(secret, 0x1)
 *          T(2) = HMAC(secret, T(1) || 0x2)
 *          T(3) = HMAC(secret, T(2) || 0x3)
 * and the first *_len bytes of T(i) are copied to the matching *_dst.
 *
 * Each output is at most BLAKE2S_HASH_SIZE bytes. Expansion stops at the
 * first output whose destination is NULL or whose length is 0, so a later
 * output is only produced when every earlier one is requested. With DEBUG
 * enabled, argument combinations breaking these rules trigger a WARN_ON.
 *
 * @chaining_key may be the same buffer as @first_dst: it is only read by
 * the extract step, before any output is written. The intermediate secret
 * and T(i) are wiped from the stack before returning.
 */
static void kdf(u8 *first_dst, u8 *second_dst, u8 *third_dst, const u8 *data,
		size_t first_len, size_t second_len, size_t third_len,
		size_t data_len, const u8 chaining_key[NOISE_HASH_LEN])
{
	/* Holds T(i) plus one spare trailing byte for the next counter. */
	u8 output[BLAKE2S_HASH_SIZE + 1];
	u8 secret[BLAKE2S_HASH_SIZE];

	WARN_ON(IS_ENABLED(DEBUG) &&
		(first_len > BLAKE2S_HASH_SIZE ||
		 second_len > BLAKE2S_HASH_SIZE ||
		 third_len > BLAKE2S_HASH_SIZE ||
		 ((second_len || second_dst || third_len || third_dst) &&
		  (!first_len || !first_dst)) ||
		 ((third_len || third_dst) && (!second_len || !second_dst))));

	/* Extract entropy from data into secret */
	hmac(secret, data, chaining_key, data_len, NOISE_HASH_LEN);

	if (!first_dst || !first_len)
		goto out;

	/* Expand first key: key = secret, data = 0x1 */
	output[0] = 1;
	hmac(output, output, secret, 1, BLAKE2S_HASH_SIZE);
	memcpy(first_dst, output, first_len);

	if (!second_dst || !second_len)
		goto out;

	/* Expand second key: key = secret, data = first-key || 0x2 */
	/* output[0..BLAKE2S_HASH_SIZE) still holds T(1); append the counter. */
	output[BLAKE2S_HASH_SIZE] = 2;
	hmac(output, output, secret, BLAKE2S_HASH_SIZE + 1, BLAKE2S_HASH_SIZE);
	memcpy(second_dst, output, second_len);

	if (!third_dst || !third_len)
		goto out;

	/* Expand third key: key = secret, data = second-key || 0x3 */
	output[BLAKE2S_HASH_SIZE] = 3;
	hmac(output, output, secret, BLAKE2S_HASH_SIZE + 1, BLAKE2S_HASH_SIZE);
	memcpy(third_dst, output, third_len);

out:
	/* Clear sensitive data from stack */
	memzero_explicit(secret, BLAKE2S_HASH_SIZE);
	memzero_explicit(output, BLAKE2S_HASH_SIZE + 1);
}
    391
    392static void derive_keys(struct noise_symmetric_key *first_dst,
    393			struct noise_symmetric_key *second_dst,
    394			const u8 chaining_key[NOISE_HASH_LEN])
    395{
    396	u64 birthdate = ktime_get_coarse_boottime_ns();
    397	kdf(first_dst->key, second_dst->key, NULL, NULL,
    398	    NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0,
    399	    chaining_key);
    400	first_dst->birthdate = second_dst->birthdate = birthdate;
    401	first_dst->is_valid = second_dst->is_valid = true;
    402}
    403
    404static bool __must_check mix_dh(u8 chaining_key[NOISE_HASH_LEN],
    405				u8 key[NOISE_SYMMETRIC_KEY_LEN],
    406				const u8 private[NOISE_PUBLIC_KEY_LEN],
    407				const u8 public[NOISE_PUBLIC_KEY_LEN])
    408{
    409	u8 dh_calculation[NOISE_PUBLIC_KEY_LEN];
    410
    411	if (unlikely(!curve25519(dh_calculation, private, public)))
    412		return false;
    413	kdf(chaining_key, key, NULL, dh_calculation, NOISE_HASH_LEN,
    414	    NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, chaining_key);
    415	memzero_explicit(dh_calculation, NOISE_PUBLIC_KEY_LEN);
    416	return true;
    417}
    418
    419static bool __must_check mix_precomputed_dh(u8 chaining_key[NOISE_HASH_LEN],
    420					    u8 key[NOISE_SYMMETRIC_KEY_LEN],
    421					    const u8 precomputed[NOISE_PUBLIC_KEY_LEN])
    422{
    423	static u8 zero_point[NOISE_PUBLIC_KEY_LEN];
    424	if (unlikely(!crypto_memneq(precomputed, zero_point, NOISE_PUBLIC_KEY_LEN)))
    425		return false;
    426	kdf(chaining_key, key, NULL, precomputed, NOISE_HASH_LEN,
    427	    NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN,
    428	    chaining_key);
    429	return true;
    430}
    431
    432static void mix_hash(u8 hash[NOISE_HASH_LEN], const u8 *src, size_t src_len)
    433{
    434	struct blake2s_state blake;
    435
    436	blake2s_init(&blake, NOISE_HASH_LEN);
    437	blake2s_update(&blake, hash, NOISE_HASH_LEN);
    438	blake2s_update(&blake, src, src_len);
    439	blake2s_final(&blake, hash);
    440}
    441
    442static void mix_psk(u8 chaining_key[NOISE_HASH_LEN], u8 hash[NOISE_HASH_LEN],
    443		    u8 key[NOISE_SYMMETRIC_KEY_LEN],
    444		    const u8 psk[NOISE_SYMMETRIC_KEY_LEN])
    445{
    446	u8 temp_hash[NOISE_HASH_LEN];
    447
    448	kdf(chaining_key, temp_hash, key, psk, NOISE_HASH_LEN, NOISE_HASH_LEN,
    449	    NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, chaining_key);
    450	mix_hash(hash, temp_hash, NOISE_HASH_LEN);
    451	memzero_explicit(temp_hash, NOISE_HASH_LEN);
    452}
    453
    454static void handshake_init(u8 chaining_key[NOISE_HASH_LEN],
    455			   u8 hash[NOISE_HASH_LEN],
    456			   const u8 remote_static[NOISE_PUBLIC_KEY_LEN])
    457{
    458	memcpy(hash, handshake_init_hash, NOISE_HASH_LEN);
    459	memcpy(chaining_key, handshake_init_chaining_key, NOISE_HASH_LEN);
    460	mix_hash(hash, remote_static, NOISE_PUBLIC_KEY_LEN);
    461}
    462
    463static void message_encrypt(u8 *dst_ciphertext, const u8 *src_plaintext,
    464			    size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN],
    465			    u8 hash[NOISE_HASH_LEN])
    466{
    467	chacha20poly1305_encrypt(dst_ciphertext, src_plaintext, src_len, hash,
    468				 NOISE_HASH_LEN,
    469				 0 /* Always zero for Noise_IK */, key);
    470	mix_hash(hash, dst_ciphertext, noise_encrypted_len(src_len));
    471}
    472
    473static bool message_decrypt(u8 *dst_plaintext, const u8 *src_ciphertext,
    474			    size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN],
    475			    u8 hash[NOISE_HASH_LEN])
    476{
    477	if (!chacha20poly1305_decrypt(dst_plaintext, src_ciphertext, src_len,
    478				      hash, NOISE_HASH_LEN,
    479				      0 /* Always zero for Noise_IK */, key))
    480		return false;
    481	mix_hash(hash, src_ciphertext, src_len);
    482	return true;
    483}
    484
    485static void message_ephemeral(u8 ephemeral_dst[NOISE_PUBLIC_KEY_LEN],
    486			      const u8 ephemeral_src[NOISE_PUBLIC_KEY_LEN],
    487			      u8 chaining_key[NOISE_HASH_LEN],
    488			      u8 hash[NOISE_HASH_LEN])
    489{
    490	if (ephemeral_dst != ephemeral_src)
    491		memcpy(ephemeral_dst, ephemeral_src, NOISE_PUBLIC_KEY_LEN);
    492	mix_hash(hash, ephemeral_src, NOISE_PUBLIC_KEY_LEN);
    493	kdf(chaining_key, NULL, NULL, ephemeral_src, NOISE_HASH_LEN, 0, 0,
    494	    NOISE_PUBLIC_KEY_LEN, chaining_key);
    495}
    496
    497static void tai64n_now(u8 output[NOISE_TIMESTAMP_LEN])
    498{
    499	struct timespec64 now;
    500
    501	ktime_get_real_ts64(&now);
    502
    503	/* In order to prevent some sort of infoleak from precise timers, we
    504	 * round down the nanoseconds part to the closest rounded-down power of
    505	 * two to the maximum initiations per second allowed anyway by the
    506	 * implementation.
    507	 */
    508	now.tv_nsec = ALIGN_DOWN(now.tv_nsec,
    509		rounddown_pow_of_two(NSEC_PER_SEC / INITIATIONS_PER_SECOND));
    510
    511	/* https://cr.yp.to/libtai/tai64.html */
    512	*(__be64 *)output = cpu_to_be64(0x400000000000000aULL + now.tv_sec);
    513	*(__be32 *)(output + sizeof(__be64)) = cpu_to_be32(now.tv_nsec);
    514}
    515
/*
 * Build a handshake initiation into @dst for @handshake (first Noise_IKpsk2
 * message: e, es, s, ss, {t}).
 *
 * On success:
 *   - the handshake's entry is inserted into the device's index hashtable;
 *   - the returned index is written to dst->sender_index;
 *   - the state becomes HANDSHAKE_CREATED_INITIATION;
 *   - true is returned.
 * Returns false if we have no static identity or a DH step fails. In that
 * case the state is not advanced, though the handshake's hash and chaining
 * key may already have been overwritten.
 *
 * Waits for the RNG, then takes static_identity->lock (read) and
 * handshake->lock (write); this function can sleep.
 */
bool
wg_noise_handshake_create_initiation(struct message_handshake_initiation *dst,
				     struct noise_handshake *handshake)
{
	u8 timestamp[NOISE_TIMESTAMP_LEN];
	u8 key[NOISE_SYMMETRIC_KEY_LEN];
	bool ret = false;

	/* We need to wait for crng _before_ taking any locks, since
	 * curve25519_generate_secret uses get_random_bytes_wait.
	 */
	wait_for_random_bytes();

	down_read(&handshake->static_identity->lock);
	down_write(&handshake->lock);

	if (unlikely(!handshake->static_identity->has_identity))
		goto out;

	dst->header.type = cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION);

	handshake_init(handshake->chaining_key, handshake->hash,
		       handshake->remote_static);

	/* e: fresh ephemeral keypair; the public half is sent in the clear */
	curve25519_generate_secret(handshake->ephemeral_private);
	if (!curve25519_generate_public(dst->unencrypted_ephemeral,
					handshake->ephemeral_private))
		goto out;
	message_ephemeral(dst->unencrypted_ephemeral,
			  dst->unencrypted_ephemeral, handshake->chaining_key,
			  handshake->hash);

	/* es: DH(our ephemeral, responder static) -> chaining key and key */
	if (!mix_dh(handshake->chaining_key, key, handshake->ephemeral_private,
		    handshake->remote_static))
		goto out;

	/* s: our static public key, encrypted under key */
	message_encrypt(dst->encrypted_static,
			handshake->static_identity->static_public,
			NOISE_PUBLIC_KEY_LEN, key, handshake->hash);

	/* ss: cached DH(our static, responder static) -> chaining key, key */
	if (!mix_precomputed_dh(handshake->chaining_key, key,
				handshake->precomputed_static_static))
		goto out;

	/* {t}: encrypted TAI64N timestamp, checked for replay by the receiver */
	tai64n_now(timestamp);
	message_encrypt(dst->encrypted_timestamp, timestamp,
			NOISE_TIMESTAMP_LEN, key, handshake->hash);

	dst->sender_index = wg_index_hashtable_insert(
		handshake->entry.peer->device->index_hashtable,
		&handshake->entry);

	handshake->state = HANDSHAKE_CREATED_INITIATION;
	ret = true;

out:
	up_write(&handshake->lock);
	up_read(&handshake->static_identity->lock);
	memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
	return ret;
}
    582
/*
 * Process a received handshake initiation @src on device @wg.
 *
 * All cryptographic work is done on stack copies, in this order:
 *   1. Decrypt the initiator's static key.
 *   2. Look up the peer by that key.
 *   3. Decrypt the timestamp.
 *   4. Reject a replay: the timestamp must be strictly newer than the last
 *      accepted one.
 *   5. Reject a flood: an initiation within
 *      NSEC_PER_SEC / INITIATIONS_PER_SECOND ns of the last one accepted.
 * On success the results are stored in the peer's handshake, whose state
 * becomes HANDSHAKE_CONSUMED_INITIATION.
 *
 * Returns the peer on success, or NULL. On failure after a successful lookup
 * the peer is released with wg_peer_put(). NOTE(review): presumably the
 * reference comes from wg_pubkey_hashtable_lookup() and is handed to the
 * caller on success — confirm.
 */
struct wg_peer *
wg_noise_handshake_consume_initiation(struct message_handshake_initiation *src,
				      struct wg_device *wg)
{
	struct wg_peer *peer = NULL, *ret_peer = NULL;
	struct noise_handshake *handshake;
	bool replay_attack, flood_attack;
	u8 key[NOISE_SYMMETRIC_KEY_LEN];
	u8 chaining_key[NOISE_HASH_LEN];
	u8 hash[NOISE_HASH_LEN];
	u8 s[NOISE_PUBLIC_KEY_LEN];
	u8 e[NOISE_PUBLIC_KEY_LEN];
	u8 t[NOISE_TIMESTAMP_LEN];
	u64 initiation_consumption;

	down_read(&wg->static_identity.lock);
	if (unlikely(!wg->static_identity.has_identity))
		goto out;

	handshake_init(chaining_key, hash, wg->static_identity.static_public);

	/* e: copy the initiator's ephemeral public key into e */
	message_ephemeral(e, src->unencrypted_ephemeral, chaining_key, hash);

	/* es: DH(our static, their ephemeral) */
	if (!mix_dh(chaining_key, key, wg->static_identity.static_private, e))
		goto out;

	/* s: decrypt the initiator's static public key into s */
	if (!message_decrypt(s, src->encrypted_static,
			     sizeof(src->encrypted_static), key, hash))
		goto out;

	/* Lookup which peer we're actually talking to */
	peer = wg_pubkey_hashtable_lookup(wg->peer_hashtable, s);
	if (!peer)
		goto out;
	handshake = &peer->handshake;

	/* ss: this peer's cached DH(our static, their static) */
	if (!mix_precomputed_dh(chaining_key, key,
				handshake->precomputed_static_static))
	    goto out;

	/* {t}: decrypt the TAI64N timestamp */
	if (!message_decrypt(t, src->encrypted_timestamp,
			     sizeof(src->encrypted_timestamp), key, hash))
		goto out;

	/* Replay: timestamp not strictly newer than the last accepted one.
	 * Flood: less than NSEC_PER_SEC / INITIATIONS_PER_SECOND ns since the
	 * last accepted initiation.
	 */
	down_read(&handshake->lock);
	replay_attack = memcmp(t, handshake->latest_timestamp,
			       NOISE_TIMESTAMP_LEN) <= 0;
	flood_attack = (s64)handshake->last_initiation_consumption +
			       NSEC_PER_SEC / INITIATIONS_PER_SECOND >
		       (s64)ktime_get_coarse_boottime_ns();
	up_read(&handshake->lock);
	if (replay_attack || flood_attack)
		goto out;

	/* Success! Copy everything to peer */
	down_write(&handshake->lock);
	memcpy(handshake->remote_ephemeral, e, NOISE_PUBLIC_KEY_LEN);
	/* The lock was dropped after the checks above, so re-compare and only
	 * ever move latest_timestamp forward.
	 */
	if (memcmp(t, handshake->latest_timestamp, NOISE_TIMESTAMP_LEN) > 0)
		memcpy(handshake->latest_timestamp, t, NOISE_TIMESTAMP_LEN);
	memcpy(handshake->hash, hash, NOISE_HASH_LEN);
	memcpy(handshake->chaining_key, chaining_key, NOISE_HASH_LEN);
	handshake->remote_index = src->sender_index;
	initiation_consumption = ktime_get_coarse_boottime_ns();
	/* Likewise only move last_initiation_consumption forward. */
	if ((s64)(handshake->last_initiation_consumption - initiation_consumption) < 0)
		handshake->last_initiation_consumption = initiation_consumption;
	handshake->state = HANDSHAKE_CONSUMED_INITIATION;
	up_write(&handshake->lock);
	ret_peer = peer;

out:
	memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
	memzero_explicit(hash, NOISE_HASH_LEN);
	memzero_explicit(chaining_key, NOISE_HASH_LEN);
	up_read(&wg->static_identity.lock);
	if (!ret_peer)
		wg_peer_put(peer);
	return ret_peer;
}
    666
/*
 * Build the handshake response into @dst (second Noise_IKpsk2 message:
 * e, ee, se, psk, {}). Only valid when @handshake is in
 * HANDSHAKE_CONSUMED_INITIATION.
 *
 * On success the entry is inserted into the device's index hashtable, the
 * returned index is written to dst->sender_index, the state becomes
 * HANDSHAKE_CREATED_RESPONSE, and true is returned. Otherwise returns false.
 *
 * Waits for the RNG, then takes static_identity->lock (read) and
 * handshake->lock (write); this function can sleep.
 */
bool wg_noise_handshake_create_response(struct message_handshake_response *dst,
					struct noise_handshake *handshake)
{
	u8 key[NOISE_SYMMETRIC_KEY_LEN];
	bool ret = false;

	/* We need to wait for crng _before_ taking any locks, since
	 * curve25519_generate_secret uses get_random_bytes_wait.
	 */
	wait_for_random_bytes();

	down_read(&handshake->static_identity->lock);
	down_write(&handshake->lock);

	if (handshake->state != HANDSHAKE_CONSUMED_INITIATION)
		goto out;

	dst->header.type = cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE);
	dst->receiver_index = handshake->remote_index;

	/* e: fresh ephemeral keypair; the public half is sent in the clear */
	curve25519_generate_secret(handshake->ephemeral_private);
	if (!curve25519_generate_public(dst->unencrypted_ephemeral,
					handshake->ephemeral_private))
		goto out;
	message_ephemeral(dst->unencrypted_ephemeral,
			  dst->unencrypted_ephemeral, handshake->chaining_key,
			  handshake->hash);

	/* ee: DH(our ephemeral, their ephemeral). key is NULL here: ee and se
	 * only advance the chaining key; the message key comes from mix_psk().
	 */
	if (!mix_dh(handshake->chaining_key, NULL, handshake->ephemeral_private,
		    handshake->remote_ephemeral))
		goto out;

	/* se: DH(our ephemeral, their static) */
	if (!mix_dh(handshake->chaining_key, NULL, handshake->ephemeral_private,
		    handshake->remote_static))
		goto out;

	/* psk: mix in the preshared key and derive key */
	mix_psk(handshake->chaining_key, handshake->hash, key,
		handshake->preshared_key);

	/* {}: empty payload; only the authentication tag is sent */
	message_encrypt(dst->encrypted_nothing, NULL, 0, key, handshake->hash);

	dst->sender_index = wg_index_hashtable_insert(
		handshake->entry.peer->device->index_hashtable,
		&handshake->entry);

	handshake->state = HANDSHAKE_CREATED_RESPONSE;
	ret = true;

out:
	up_write(&handshake->lock);
	up_read(&handshake->static_identity->lock);
	memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
	return ret;
}
    726
    727struct wg_peer *
    728wg_noise_handshake_consume_response(struct message_handshake_response *src,
    729				    struct wg_device *wg)
    730{
    731	enum noise_handshake_state state = HANDSHAKE_ZEROED;
    732	struct wg_peer *peer = NULL, *ret_peer = NULL;
    733	struct noise_handshake *handshake;
    734	u8 key[NOISE_SYMMETRIC_KEY_LEN];
    735	u8 hash[NOISE_HASH_LEN];
    736	u8 chaining_key[NOISE_HASH_LEN];
    737	u8 e[NOISE_PUBLIC_KEY_LEN];
    738	u8 ephemeral_private[NOISE_PUBLIC_KEY_LEN];
    739	u8 static_private[NOISE_PUBLIC_KEY_LEN];
    740	u8 preshared_key[NOISE_SYMMETRIC_KEY_LEN];
    741
    742	down_read(&wg->static_identity.lock);
    743
    744	if (unlikely(!wg->static_identity.has_identity))
    745		goto out;
    746
    747	handshake = (struct noise_handshake *)wg_index_hashtable_lookup(
    748		wg->index_hashtable, INDEX_HASHTABLE_HANDSHAKE,
    749		src->receiver_index, &peer);
    750	if (unlikely(!handshake))
    751		goto out;
    752
    753	down_read(&handshake->lock);
    754	state = handshake->state;
    755	memcpy(hash, handshake->hash, NOISE_HASH_LEN);
    756	memcpy(chaining_key, handshake->chaining_key, NOISE_HASH_LEN);
    757	memcpy(ephemeral_private, handshake->ephemeral_private,
    758	       NOISE_PUBLIC_KEY_LEN);
    759	memcpy(preshared_key, handshake->preshared_key,
    760	       NOISE_SYMMETRIC_KEY_LEN);
    761	up_read(&handshake->lock);
    762
    763	if (state != HANDSHAKE_CREATED_INITIATION)
    764		goto fail;
    765
    766	/* e */
    767	message_ephemeral(e, src->unencrypted_ephemeral, chaining_key, hash);
    768
    769	/* ee */
    770	if (!mix_dh(chaining_key, NULL, ephemeral_private, e))
    771		goto fail;
    772
    773	/* se */
    774	if (!mix_dh(chaining_key, NULL, wg->static_identity.static_private, e))
    775		goto fail;
    776
    777	/* psk */
    778	mix_psk(chaining_key, hash, key, preshared_key);
    779
    780	/* {} */
    781	if (!message_decrypt(NULL, src->encrypted_nothing,
    782			     sizeof(src->encrypted_nothing), key, hash))
    783		goto fail;
    784
    785	/* Success! Copy everything to peer */
    786	down_write(&handshake->lock);
    787	/* It's important to check that the state is still the same, while we
    788	 * have an exclusive lock.
    789	 */
    790	if (handshake->state != state) {
    791		up_write(&handshake->lock);
    792		goto fail;
    793	}
    794	memcpy(handshake->remote_ephemeral, e, NOISE_PUBLIC_KEY_LEN);
    795	memcpy(handshake->hash, hash, NOISE_HASH_LEN);
    796	memcpy(handshake->chaining_key, chaining_key, NOISE_HASH_LEN);
    797	handshake->remote_index = src->sender_index;
    798	handshake->state = HANDSHAKE_CONSUMED_RESPONSE;
    799	up_write(&handshake->lock);
    800	ret_peer = peer;
    801	goto out;
    802
    803fail:
    804	wg_peer_put(peer);
    805out:
    806	memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
    807	memzero_explicit(hash, NOISE_HASH_LEN);
    808	memzero_explicit(chaining_key, NOISE_HASH_LEN);
    809	memzero_explicit(ephemeral_private, NOISE_PUBLIC_KEY_LEN);
    810	memzero_explicit(static_private, NOISE_PUBLIC_KEY_LEN);
    811	memzero_explicit(preshared_key, NOISE_SYMMETRIC_KEY_LEN);
    812	up_read(&wg->static_identity.lock);
    813	return ret_peer;
    814}
    815
/*
 * Turn a completed handshake into a live transport keypair.
 *
 * Valid only in HANDSHAKE_CREATED_RESPONSE (we are the responder) or
 * HANDSHAKE_CONSUMED_RESPONSE (we are the initiator). The function:
 *   1. Derives the transport keys from the final chaining key; the
 *      sending/receiving assignment depends on the role.
 *   2. Wipes the handshake.
 *   3. Unless the peer is marked is_dead, installs the keypair with
 *      add_new_keypair() and replaces the handshake's index-hashtable entry
 *      with the keypair's.
 *
 * Returns the result of wg_index_hashtable_replace(). Returns false if the
 * state was wrong, the allocation failed, or the peer is dead (the new
 * keypair is then freed). Runs entirely under handshake->lock (write).
 */
bool wg_noise_handshake_begin_session(struct noise_handshake *handshake,
				      struct noise_keypairs *keypairs)
{
	struct noise_keypair *new_keypair;
	bool ret = false;

	down_write(&handshake->lock);
	if (handshake->state != HANDSHAKE_CREATED_RESPONSE &&
	    handshake->state != HANDSHAKE_CONSUMED_RESPONSE)
		goto out;

	new_keypair = keypair_create(handshake->entry.peer);
	if (!new_keypair)
		goto out;
	new_keypair->i_am_the_initiator = handshake->state ==
					  HANDSHAKE_CONSUMED_RESPONSE;
	new_keypair->remote_index = handshake->remote_index;

	/* The initiator's first output is its sending key, while the
	 * responder's first output is its receiving key, so both ends agree.
	 */
	if (new_keypair->i_am_the_initiator)
		derive_keys(&new_keypair->sending, &new_keypair->receiving,
			    handshake->chaining_key);
	else
		derive_keys(&new_keypair->receiving, &new_keypair->sending,
			    handshake->chaining_key);

	handshake_zero(handshake);
	rcu_read_lock_bh();
	if (likely(!READ_ONCE(container_of(handshake, struct wg_peer,
					   handshake)->is_dead))) {
		add_new_keypair(keypairs, new_keypair);
		net_dbg_ratelimited("%s: Keypair %llu created for peer %llu\n",
				    handshake->entry.peer->device->dev->name,
				    new_keypair->internal_id,
				    handshake->entry.peer->internal_id);
		ret = wg_index_hashtable_replace(
			handshake->entry.peer->device->index_hashtable,
			&handshake->entry, &new_keypair->entry);
	} else {
		/* Never published via add_new_keypair(), so free it directly. */
		kfree_sensitive(new_keypair);
	}
	rcu_read_unlock_bh();

out:
	up_write(&handshake->lock);
	return ret;
}
    861}