cachepc-linux

Fork of AMDESE/linux with modifications for CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux
Log | Files | Refs | README | LICENSE | sfeed.txt

aes-glue.c (27155B)


      1// SPDX-License-Identifier: GPL-2.0-only
      2/*
      3 * linux/arch/arm64/crypto/aes-glue.c - wrapper code for ARMv8 AES
      4 *
      5 * Copyright (C) 2013 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
      6 */
      7
      8#include <asm/neon.h>
      9#include <asm/hwcap.h>
     10#include <asm/simd.h>
     11#include <crypto/aes.h>
     12#include <crypto/ctr.h>
     13#include <crypto/sha2.h>
     14#include <crypto/internal/hash.h>
     15#include <crypto/internal/simd.h>
     16#include <crypto/internal/skcipher.h>
     17#include <crypto/scatterwalk.h>
     18#include <linux/module.h>
     19#include <linux/cpufeature.h>
     20#include <crypto/xts.h>
     21
     22#include "aes-ce-setkey.h"
     23
     24#ifdef USE_V8_CRYPTO_EXTENSIONS
     25#define MODE			"ce"
     26#define PRIO			300
     27#define aes_expandkey		ce_aes_expandkey
     28#define aes_ecb_encrypt		ce_aes_ecb_encrypt
     29#define aes_ecb_decrypt		ce_aes_ecb_decrypt
     30#define aes_cbc_encrypt		ce_aes_cbc_encrypt
     31#define aes_cbc_decrypt		ce_aes_cbc_decrypt
     32#define aes_cbc_cts_encrypt	ce_aes_cbc_cts_encrypt
     33#define aes_cbc_cts_decrypt	ce_aes_cbc_cts_decrypt
     34#define aes_essiv_cbc_encrypt	ce_aes_essiv_cbc_encrypt
     35#define aes_essiv_cbc_decrypt	ce_aes_essiv_cbc_decrypt
     36#define aes_ctr_encrypt		ce_aes_ctr_encrypt
     37#define aes_xts_encrypt		ce_aes_xts_encrypt
     38#define aes_xts_decrypt		ce_aes_xts_decrypt
     39#define aes_mac_update		ce_aes_mac_update
     40MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 Crypto Extensions");
     41#else
     42#define MODE			"neon"
     43#define PRIO			200
     44#define aes_ecb_encrypt		neon_aes_ecb_encrypt
     45#define aes_ecb_decrypt		neon_aes_ecb_decrypt
     46#define aes_cbc_encrypt		neon_aes_cbc_encrypt
     47#define aes_cbc_decrypt		neon_aes_cbc_decrypt
     48#define aes_cbc_cts_encrypt	neon_aes_cbc_cts_encrypt
     49#define aes_cbc_cts_decrypt	neon_aes_cbc_cts_decrypt
     50#define aes_essiv_cbc_encrypt	neon_aes_essiv_cbc_encrypt
     51#define aes_essiv_cbc_decrypt	neon_aes_essiv_cbc_decrypt
     52#define aes_ctr_encrypt		neon_aes_ctr_encrypt
     53#define aes_xts_encrypt		neon_aes_xts_encrypt
     54#define aes_xts_decrypt		neon_aes_xts_decrypt
     55#define aes_mac_update		neon_aes_mac_update
     56MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 NEON");
     57#endif
     58#if defined(USE_V8_CRYPTO_EXTENSIONS) || !IS_ENABLED(CONFIG_CRYPTO_AES_ARM64_BS)
     59MODULE_ALIAS_CRYPTO("ecb(aes)");
     60MODULE_ALIAS_CRYPTO("cbc(aes)");
     61MODULE_ALIAS_CRYPTO("ctr(aes)");
     62MODULE_ALIAS_CRYPTO("xts(aes)");
     63#endif
     64MODULE_ALIAS_CRYPTO("cts(cbc(aes))");
     65MODULE_ALIAS_CRYPTO("essiv(cbc(aes),sha256)");
     66MODULE_ALIAS_CRYPTO("cmac(aes)");
     67MODULE_ALIAS_CRYPTO("xcbc(aes)");
     68MODULE_ALIAS_CRYPTO("cbcmac(aes)");
     69
     70MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
     71MODULE_LICENSE("GPL v2");
     72
     73/* defined in aes-modes.S */
     74asmlinkage void aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
     75				int rounds, int blocks);
     76asmlinkage void aes_ecb_decrypt(u8 out[], u8 const in[], u32 const rk[],
     77				int rounds, int blocks);
     78
     79asmlinkage void aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
     80				int rounds, int blocks, u8 iv[]);
     81asmlinkage void aes_cbc_decrypt(u8 out[], u8 const in[], u32 const rk[],
     82				int rounds, int blocks, u8 iv[]);
     83
     84asmlinkage void aes_cbc_cts_encrypt(u8 out[], u8 const in[], u32 const rk[],
     85				int rounds, int bytes, u8 const iv[]);
     86asmlinkage void aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[],
     87				int rounds, int bytes, u8 const iv[]);
     88
     89asmlinkage void aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
     90				int rounds, int bytes, u8 ctr[]);
     91
     92asmlinkage void aes_xts_encrypt(u8 out[], u8 const in[], u32 const rk1[],
     93				int rounds, int bytes, u32 const rk2[], u8 iv[],
     94				int first);
     95asmlinkage void aes_xts_decrypt(u8 out[], u8 const in[], u32 const rk1[],
     96				int rounds, int bytes, u32 const rk2[], u8 iv[],
     97				int first);
     98
     99asmlinkage void aes_essiv_cbc_encrypt(u8 out[], u8 const in[], u32 const rk1[],
    100				      int rounds, int blocks, u8 iv[],
    101				      u32 const rk2[]);
    102asmlinkage void aes_essiv_cbc_decrypt(u8 out[], u8 const in[], u32 const rk1[],
    103				      int rounds, int blocks, u8 iv[],
    104				      u32 const rk2[]);
    105
    106asmlinkage int aes_mac_update(u8 const in[], u32 const rk[], int rounds,
    107			      int blocks, u8 dg[], int enc_before,
    108			      int enc_after);
    109
    110struct crypto_aes_xts_ctx {
    111	struct crypto_aes_ctx key1;
    112	struct crypto_aes_ctx __aligned(8) key2;
    113};
    114
    115struct crypto_aes_essiv_cbc_ctx {
    116	struct crypto_aes_ctx key1;
    117	struct crypto_aes_ctx __aligned(8) key2;
    118	struct crypto_shash *hash;
    119};
    120
    121struct mac_tfm_ctx {
    122	struct crypto_aes_ctx key;
    123	u8 __aligned(8) consts[];
    124};
    125
/* Per-request MAC state. */
struct mac_desc_ctx {
	/* number of bytes already XORed into dg for the current block */
	unsigned int len;
	/* running digest / CBC chaining value */
	u8 dg[AES_BLOCK_SIZE];
};
    130
    131static int skcipher_aes_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
    132			       unsigned int key_len)
    133{
    134	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
    135
    136	return aes_expandkey(ctx, in_key, key_len);
    137}
    138
    139static int __maybe_unused xts_set_key(struct crypto_skcipher *tfm,
    140				      const u8 *in_key, unsigned int key_len)
    141{
    142	struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
    143	int ret;
    144
    145	ret = xts_verify_key(tfm, in_key, key_len);
    146	if (ret)
    147		return ret;
    148
    149	ret = aes_expandkey(&ctx->key1, in_key, key_len / 2);
    150	if (!ret)
    151		ret = aes_expandkey(&ctx->key2, &in_key[key_len / 2],
    152				    key_len / 2);
    153	return ret;
    154}
    155
    156static int __maybe_unused essiv_cbc_set_key(struct crypto_skcipher *tfm,
    157					    const u8 *in_key,
    158					    unsigned int key_len)
    159{
    160	struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
    161	u8 digest[SHA256_DIGEST_SIZE];
    162	int ret;
    163
    164	ret = aes_expandkey(&ctx->key1, in_key, key_len);
    165	if (ret)
    166		return ret;
    167
    168	crypto_shash_tfm_digest(ctx->hash, in_key, key_len, digest);
    169
    170	return aes_expandkey(&ctx->key2, digest, sizeof(digest));
    171}
    172
    173static int __maybe_unused ecb_encrypt(struct skcipher_request *req)
    174{
    175	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
    176	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
    177	int err, rounds = 6 + ctx->key_length / 4;
    178	struct skcipher_walk walk;
    179	unsigned int blocks;
    180
    181	err = skcipher_walk_virt(&walk, req, false);
    182
    183	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
    184		kernel_neon_begin();
    185		aes_ecb_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
    186				ctx->key_enc, rounds, blocks);
    187		kernel_neon_end();
    188		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
    189	}
    190	return err;
    191}
    192
    193static int __maybe_unused ecb_decrypt(struct skcipher_request *req)
    194{
    195	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
    196	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
    197	int err, rounds = 6 + ctx->key_length / 4;
    198	struct skcipher_walk walk;
    199	unsigned int blocks;
    200
    201	err = skcipher_walk_virt(&walk, req, false);
    202
    203	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
    204		kernel_neon_begin();
    205		aes_ecb_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
    206				ctx->key_dec, rounds, blocks);
    207		kernel_neon_end();
    208		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
    209	}
    210	return err;
    211}
    212
    213static int cbc_encrypt_walk(struct skcipher_request *req,
    214			    struct skcipher_walk *walk)
    215{
    216	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
    217	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
    218	int err = 0, rounds = 6 + ctx->key_length / 4;
    219	unsigned int blocks;
    220
    221	while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
    222		kernel_neon_begin();
    223		aes_cbc_encrypt(walk->dst.virt.addr, walk->src.virt.addr,
    224				ctx->key_enc, rounds, blocks, walk->iv);
    225		kernel_neon_end();
    226		err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
    227	}
    228	return err;
    229}
    230
    231static int __maybe_unused cbc_encrypt(struct skcipher_request *req)
    232{
    233	struct skcipher_walk walk;
    234	int err;
    235
    236	err = skcipher_walk_virt(&walk, req, false);
    237	if (err)
    238		return err;
    239	return cbc_encrypt_walk(req, &walk);
    240}
    241
    242static int cbc_decrypt_walk(struct skcipher_request *req,
    243			    struct skcipher_walk *walk)
    244{
    245	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
    246	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
    247	int err = 0, rounds = 6 + ctx->key_length / 4;
    248	unsigned int blocks;
    249
    250	while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
    251		kernel_neon_begin();
    252		aes_cbc_decrypt(walk->dst.virt.addr, walk->src.virt.addr,
    253				ctx->key_dec, rounds, blocks, walk->iv);
    254		kernel_neon_end();
    255		err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
    256	}
    257	return err;
    258}
    259
    260static int __maybe_unused cbc_decrypt(struct skcipher_request *req)
    261{
    262	struct skcipher_walk walk;
    263	int err;
    264
    265	err = skcipher_walk_virt(&walk, req, false);
    266	if (err)
    267		return err;
    268	return cbc_decrypt_walk(req, &walk);
    269}
    270
    271static int cts_cbc_encrypt(struct skcipher_request *req)
    272{
    273	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
    274	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
    275	int err, rounds = 6 + ctx->key_length / 4;
    276	int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
    277	struct scatterlist *src = req->src, *dst = req->dst;
    278	struct scatterlist sg_src[2], sg_dst[2];
    279	struct skcipher_request subreq;
    280	struct skcipher_walk walk;
    281
    282	skcipher_request_set_tfm(&subreq, tfm);
    283	skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
    284				      NULL, NULL);
    285
    286	if (req->cryptlen <= AES_BLOCK_SIZE) {
    287		if (req->cryptlen < AES_BLOCK_SIZE)
    288			return -EINVAL;
    289		cbc_blocks = 1;
    290	}
    291
    292	if (cbc_blocks > 0) {
    293		skcipher_request_set_crypt(&subreq, req->src, req->dst,
    294					   cbc_blocks * AES_BLOCK_SIZE,
    295					   req->iv);
    296
    297		err = skcipher_walk_virt(&walk, &subreq, false) ?:
    298		      cbc_encrypt_walk(&subreq, &walk);
    299		if (err)
    300			return err;
    301
    302		if (req->cryptlen == AES_BLOCK_SIZE)
    303			return 0;
    304
    305		dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
    306		if (req->dst != req->src)
    307			dst = scatterwalk_ffwd(sg_dst, req->dst,
    308					       subreq.cryptlen);
    309	}
    310
    311	/* handle ciphertext stealing */
    312	skcipher_request_set_crypt(&subreq, src, dst,
    313				   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
    314				   req->iv);
    315
    316	err = skcipher_walk_virt(&walk, &subreq, false);
    317	if (err)
    318		return err;
    319
    320	kernel_neon_begin();
    321	aes_cbc_cts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
    322			    ctx->key_enc, rounds, walk.nbytes, walk.iv);
    323	kernel_neon_end();
    324
    325	return skcipher_walk_done(&walk, 0);
    326}
    327
    328static int cts_cbc_decrypt(struct skcipher_request *req)
    329{
    330	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
    331	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
    332	int err, rounds = 6 + ctx->key_length / 4;
    333	int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
    334	struct scatterlist *src = req->src, *dst = req->dst;
    335	struct scatterlist sg_src[2], sg_dst[2];
    336	struct skcipher_request subreq;
    337	struct skcipher_walk walk;
    338
    339	skcipher_request_set_tfm(&subreq, tfm);
    340	skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
    341				      NULL, NULL);
    342
    343	if (req->cryptlen <= AES_BLOCK_SIZE) {
    344		if (req->cryptlen < AES_BLOCK_SIZE)
    345			return -EINVAL;
    346		cbc_blocks = 1;
    347	}
    348
    349	if (cbc_blocks > 0) {
    350		skcipher_request_set_crypt(&subreq, req->src, req->dst,
    351					   cbc_blocks * AES_BLOCK_SIZE,
    352					   req->iv);
    353
    354		err = skcipher_walk_virt(&walk, &subreq, false) ?:
    355		      cbc_decrypt_walk(&subreq, &walk);
    356		if (err)
    357			return err;
    358
    359		if (req->cryptlen == AES_BLOCK_SIZE)
    360			return 0;
    361
    362		dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
    363		if (req->dst != req->src)
    364			dst = scatterwalk_ffwd(sg_dst, req->dst,
    365					       subreq.cryptlen);
    366	}
    367
    368	/* handle ciphertext stealing */
    369	skcipher_request_set_crypt(&subreq, src, dst,
    370				   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
    371				   req->iv);
    372
    373	err = skcipher_walk_virt(&walk, &subreq, false);
    374	if (err)
    375		return err;
    376
    377	kernel_neon_begin();
    378	aes_cbc_cts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
    379			    ctx->key_dec, rounds, walk.nbytes, walk.iv);
    380	kernel_neon_end();
    381
    382	return skcipher_walk_done(&walk, 0);
    383}
    384
    385static int __maybe_unused essiv_cbc_init_tfm(struct crypto_skcipher *tfm)
    386{
    387	struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
    388
    389	ctx->hash = crypto_alloc_shash("sha256", 0, 0);
    390
    391	return PTR_ERR_OR_ZERO(ctx->hash);
    392}
    393
    394static void __maybe_unused essiv_cbc_exit_tfm(struct crypto_skcipher *tfm)
    395{
    396	struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
    397
    398	crypto_free_shash(ctx->hash);
    399}
    400
    401static int __maybe_unused essiv_cbc_encrypt(struct skcipher_request *req)
    402{
    403	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
    404	struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
    405	int err, rounds = 6 + ctx->key1.key_length / 4;
    406	struct skcipher_walk walk;
    407	unsigned int blocks;
    408
    409	err = skcipher_walk_virt(&walk, req, false);
    410
    411	blocks = walk.nbytes / AES_BLOCK_SIZE;
    412	if (blocks) {
    413		kernel_neon_begin();
    414		aes_essiv_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
    415				      ctx->key1.key_enc, rounds, blocks,
    416				      req->iv, ctx->key2.key_enc);
    417		kernel_neon_end();
    418		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
    419	}
    420	return err ?: cbc_encrypt_walk(req, &walk);
    421}
    422
    423static int __maybe_unused essiv_cbc_decrypt(struct skcipher_request *req)
    424{
    425	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
    426	struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
    427	int err, rounds = 6 + ctx->key1.key_length / 4;
    428	struct skcipher_walk walk;
    429	unsigned int blocks;
    430
    431	err = skcipher_walk_virt(&walk, req, false);
    432
    433	blocks = walk.nbytes / AES_BLOCK_SIZE;
    434	if (blocks) {
    435		kernel_neon_begin();
    436		aes_essiv_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
    437				      ctx->key1.key_dec, rounds, blocks,
    438				      req->iv, ctx->key2.key_enc);
    439		kernel_neon_end();
    440		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
    441	}
    442	return err ?: cbc_decrypt_walk(req, &walk);
    443}
    444
    445static int __maybe_unused ctr_encrypt(struct skcipher_request *req)
    446{
    447	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
    448	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
    449	int err, rounds = 6 + ctx->key_length / 4;
    450	struct skcipher_walk walk;
    451
    452	err = skcipher_walk_virt(&walk, req, false);
    453
    454	while (walk.nbytes > 0) {
    455		const u8 *src = walk.src.virt.addr;
    456		unsigned int nbytes = walk.nbytes;
    457		u8 *dst = walk.dst.virt.addr;
    458		u8 buf[AES_BLOCK_SIZE];
    459
    460		if (unlikely(nbytes < AES_BLOCK_SIZE))
    461			src = dst = memcpy(buf + sizeof(buf) - nbytes,
    462					   src, nbytes);
    463		else if (nbytes < walk.total)
    464			nbytes &= ~(AES_BLOCK_SIZE - 1);
    465
    466		kernel_neon_begin();
    467		aes_ctr_encrypt(dst, src, ctx->key_enc, rounds, nbytes,
    468				walk.iv);
    469		kernel_neon_end();
    470
    471		if (unlikely(nbytes < AES_BLOCK_SIZE))
    472			memcpy(walk.dst.virt.addr,
    473			       buf + sizeof(buf) - nbytes, nbytes);
    474
    475		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
    476	}
    477
    478	return err;
    479}
    480
    481static int __maybe_unused xts_encrypt(struct skcipher_request *req)
    482{
    483	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
    484	struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
    485	int err, first, rounds = 6 + ctx->key1.key_length / 4;
    486	int tail = req->cryptlen % AES_BLOCK_SIZE;
    487	struct scatterlist sg_src[2], sg_dst[2];
    488	struct skcipher_request subreq;
    489	struct scatterlist *src, *dst;
    490	struct skcipher_walk walk;
    491
    492	if (req->cryptlen < AES_BLOCK_SIZE)
    493		return -EINVAL;
    494
    495	err = skcipher_walk_virt(&walk, req, false);
    496
    497	if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
    498		int xts_blocks = DIV_ROUND_UP(req->cryptlen,
    499					      AES_BLOCK_SIZE) - 2;
    500
    501		skcipher_walk_abort(&walk);
    502
    503		skcipher_request_set_tfm(&subreq, tfm);
    504		skcipher_request_set_callback(&subreq,
    505					      skcipher_request_flags(req),
    506					      NULL, NULL);
    507		skcipher_request_set_crypt(&subreq, req->src, req->dst,
    508					   xts_blocks * AES_BLOCK_SIZE,
    509					   req->iv);
    510		req = &subreq;
    511		err = skcipher_walk_virt(&walk, req, false);
    512	} else {
    513		tail = 0;
    514	}
    515
    516	for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
    517		int nbytes = walk.nbytes;
    518
    519		if (walk.nbytes < walk.total)
    520			nbytes &= ~(AES_BLOCK_SIZE - 1);
    521
    522		kernel_neon_begin();
    523		aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
    524				ctx->key1.key_enc, rounds, nbytes,
    525				ctx->key2.key_enc, walk.iv, first);
    526		kernel_neon_end();
    527		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
    528	}
    529
    530	if (err || likely(!tail))
    531		return err;
    532
    533	dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
    534	if (req->dst != req->src)
    535		dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);
    536
    537	skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
    538				   req->iv);
    539
    540	err = skcipher_walk_virt(&walk, &subreq, false);
    541	if (err)
    542		return err;
    543
    544	kernel_neon_begin();
    545	aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
    546			ctx->key1.key_enc, rounds, walk.nbytes,
    547			ctx->key2.key_enc, walk.iv, first);
    548	kernel_neon_end();
    549
    550	return skcipher_walk_done(&walk, 0);
    551}
    552
    553static int __maybe_unused xts_decrypt(struct skcipher_request *req)
    554{
    555	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
    556	struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
    557	int err, first, rounds = 6 + ctx->key1.key_length / 4;
    558	int tail = req->cryptlen % AES_BLOCK_SIZE;
    559	struct scatterlist sg_src[2], sg_dst[2];
    560	struct skcipher_request subreq;
    561	struct scatterlist *src, *dst;
    562	struct skcipher_walk walk;
    563
    564	if (req->cryptlen < AES_BLOCK_SIZE)
    565		return -EINVAL;
    566
    567	err = skcipher_walk_virt(&walk, req, false);
    568
    569	if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
    570		int xts_blocks = DIV_ROUND_UP(req->cryptlen,
    571					      AES_BLOCK_SIZE) - 2;
    572
    573		skcipher_walk_abort(&walk);
    574
    575		skcipher_request_set_tfm(&subreq, tfm);
    576		skcipher_request_set_callback(&subreq,
    577					      skcipher_request_flags(req),
    578					      NULL, NULL);
    579		skcipher_request_set_crypt(&subreq, req->src, req->dst,
    580					   xts_blocks * AES_BLOCK_SIZE,
    581					   req->iv);
    582		req = &subreq;
    583		err = skcipher_walk_virt(&walk, req, false);
    584	} else {
    585		tail = 0;
    586	}
    587
    588	for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
    589		int nbytes = walk.nbytes;
    590
    591		if (walk.nbytes < walk.total)
    592			nbytes &= ~(AES_BLOCK_SIZE - 1);
    593
    594		kernel_neon_begin();
    595		aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
    596				ctx->key1.key_dec, rounds, nbytes,
    597				ctx->key2.key_enc, walk.iv, first);
    598		kernel_neon_end();
    599		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
    600	}
    601
    602	if (err || likely(!tail))
    603		return err;
    604
    605	dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
    606	if (req->dst != req->src)
    607		dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);
    608
    609	skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
    610				   req->iv);
    611
    612	err = skcipher_walk_virt(&walk, &subreq, false);
    613	if (err)
    614		return err;
    615
    616
    617	kernel_neon_begin();
    618	aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
    619			ctx->key1.key_dec, rounds, walk.nbytes,
    620			ctx->key2.key_enc, walk.iv, first);
    621	kernel_neon_end();
    622
    623	return skcipher_walk_done(&walk, 0);
    624}
    625
    626static struct skcipher_alg aes_algs[] = { {
    627#if defined(USE_V8_CRYPTO_EXTENSIONS) || !IS_ENABLED(CONFIG_CRYPTO_AES_ARM64_BS)
    628	.base = {
    629		.cra_name		= "ecb(aes)",
    630		.cra_driver_name	= "ecb-aes-" MODE,
    631		.cra_priority		= PRIO,
    632		.cra_blocksize		= AES_BLOCK_SIZE,
    633		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
    634		.cra_module		= THIS_MODULE,
    635	},
    636	.min_keysize	= AES_MIN_KEY_SIZE,
    637	.max_keysize	= AES_MAX_KEY_SIZE,
    638	.setkey		= skcipher_aes_setkey,
    639	.encrypt	= ecb_encrypt,
    640	.decrypt	= ecb_decrypt,
    641}, {
    642	.base = {
    643		.cra_name		= "cbc(aes)",
    644		.cra_driver_name	= "cbc-aes-" MODE,
    645		.cra_priority		= PRIO,
    646		.cra_blocksize		= AES_BLOCK_SIZE,
    647		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
    648		.cra_module		= THIS_MODULE,
    649	},
    650	.min_keysize	= AES_MIN_KEY_SIZE,
    651	.max_keysize	= AES_MAX_KEY_SIZE,
    652	.ivsize		= AES_BLOCK_SIZE,
    653	.setkey		= skcipher_aes_setkey,
    654	.encrypt	= cbc_encrypt,
    655	.decrypt	= cbc_decrypt,
    656}, {
    657	.base = {
    658		.cra_name		= "ctr(aes)",
    659		.cra_driver_name	= "ctr-aes-" MODE,
    660		.cra_priority		= PRIO,
    661		.cra_blocksize		= 1,
    662		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
    663		.cra_module		= THIS_MODULE,
    664	},
    665	.min_keysize	= AES_MIN_KEY_SIZE,
    666	.max_keysize	= AES_MAX_KEY_SIZE,
    667	.ivsize		= AES_BLOCK_SIZE,
    668	.chunksize	= AES_BLOCK_SIZE,
    669	.setkey		= skcipher_aes_setkey,
    670	.encrypt	= ctr_encrypt,
    671	.decrypt	= ctr_encrypt,
    672}, {
    673	.base = {
    674		.cra_name		= "xts(aes)",
    675		.cra_driver_name	= "xts-aes-" MODE,
    676		.cra_priority		= PRIO,
    677		.cra_blocksize		= AES_BLOCK_SIZE,
    678		.cra_ctxsize		= sizeof(struct crypto_aes_xts_ctx),
    679		.cra_module		= THIS_MODULE,
    680	},
    681	.min_keysize	= 2 * AES_MIN_KEY_SIZE,
    682	.max_keysize	= 2 * AES_MAX_KEY_SIZE,
    683	.ivsize		= AES_BLOCK_SIZE,
    684	.walksize	= 2 * AES_BLOCK_SIZE,
    685	.setkey		= xts_set_key,
    686	.encrypt	= xts_encrypt,
    687	.decrypt	= xts_decrypt,
    688}, {
    689#endif
    690	.base = {
    691		.cra_name		= "cts(cbc(aes))",
    692		.cra_driver_name	= "cts-cbc-aes-" MODE,
    693		.cra_priority		= PRIO,
    694		.cra_blocksize		= AES_BLOCK_SIZE,
    695		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
    696		.cra_module		= THIS_MODULE,
    697	},
    698	.min_keysize	= AES_MIN_KEY_SIZE,
    699	.max_keysize	= AES_MAX_KEY_SIZE,
    700	.ivsize		= AES_BLOCK_SIZE,
    701	.walksize	= 2 * AES_BLOCK_SIZE,
    702	.setkey		= skcipher_aes_setkey,
    703	.encrypt	= cts_cbc_encrypt,
    704	.decrypt	= cts_cbc_decrypt,
    705}, {
    706	.base = {
    707		.cra_name		= "essiv(cbc(aes),sha256)",
    708		.cra_driver_name	= "essiv-cbc-aes-sha256-" MODE,
    709		.cra_priority		= PRIO + 1,
    710		.cra_blocksize		= AES_BLOCK_SIZE,
    711		.cra_ctxsize		= sizeof(struct crypto_aes_essiv_cbc_ctx),
    712		.cra_module		= THIS_MODULE,
    713	},
    714	.min_keysize	= AES_MIN_KEY_SIZE,
    715	.max_keysize	= AES_MAX_KEY_SIZE,
    716	.ivsize		= AES_BLOCK_SIZE,
    717	.setkey		= essiv_cbc_set_key,
    718	.encrypt	= essiv_cbc_encrypt,
    719	.decrypt	= essiv_cbc_decrypt,
    720	.init		= essiv_cbc_init_tfm,
    721	.exit		= essiv_cbc_exit_tfm,
    722} };
    723
    724static int cbcmac_setkey(struct crypto_shash *tfm, const u8 *in_key,
    725			 unsigned int key_len)
    726{
    727	struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
    728
    729	return aes_expandkey(&ctx->key, in_key, key_len);
    730}
    731
    732static void cmac_gf128_mul_by_x(be128 *y, const be128 *x)
    733{
    734	u64 a = be64_to_cpu(x->a);
    735	u64 b = be64_to_cpu(x->b);
    736
    737	y->a = cpu_to_be64((a << 1) | (b >> 63));
    738	y->b = cpu_to_be64((b << 1) ^ ((a >> 63) ? 0x87 : 0));
    739}
    740
    741static int cmac_setkey(struct crypto_shash *tfm, const u8 *in_key,
    742		       unsigned int key_len)
    743{
    744	struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
    745	be128 *consts = (be128 *)ctx->consts;
    746	int rounds = 6 + key_len / 4;
    747	int err;
    748
    749	err = cbcmac_setkey(tfm, in_key, key_len);
    750	if (err)
    751		return err;
    752
    753	/* encrypt the zero vector */
    754	kernel_neon_begin();
    755	aes_ecb_encrypt(ctx->consts, (u8[AES_BLOCK_SIZE]){}, ctx->key.key_enc,
    756			rounds, 1);
    757	kernel_neon_end();
    758
    759	cmac_gf128_mul_by_x(consts, consts);
    760	cmac_gf128_mul_by_x(consts + 1, consts);
    761
    762	return 0;
    763}
    764
    765static int xcbc_setkey(struct crypto_shash *tfm, const u8 *in_key,
    766		       unsigned int key_len)
    767{
    768	static u8 const ks[3][AES_BLOCK_SIZE] = {
    769		{ [0 ... AES_BLOCK_SIZE - 1] = 0x1 },
    770		{ [0 ... AES_BLOCK_SIZE - 1] = 0x2 },
    771		{ [0 ... AES_BLOCK_SIZE - 1] = 0x3 },
    772	};
    773
    774	struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
    775	int rounds = 6 + key_len / 4;
    776	u8 key[AES_BLOCK_SIZE];
    777	int err;
    778
    779	err = cbcmac_setkey(tfm, in_key, key_len);
    780	if (err)
    781		return err;
    782
    783	kernel_neon_begin();
    784	aes_ecb_encrypt(key, ks[0], ctx->key.key_enc, rounds, 1);
    785	aes_ecb_encrypt(ctx->consts, ks[1], ctx->key.key_enc, rounds, 2);
    786	kernel_neon_end();
    787
    788	return cbcmac_setkey(tfm, key, sizeof(key));
    789}
    790
    791static int mac_init(struct shash_desc *desc)
    792{
    793	struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
    794
    795	memset(ctx->dg, 0, AES_BLOCK_SIZE);
    796	ctx->len = 0;
    797
    798	return 0;
    799}
    800
    801static void mac_do_update(struct crypto_aes_ctx *ctx, u8 const in[], int blocks,
    802			  u8 dg[], int enc_before, int enc_after)
    803{
    804	int rounds = 6 + ctx->key_length / 4;
    805
    806	if (crypto_simd_usable()) {
    807		int rem;
    808
    809		do {
    810			kernel_neon_begin();
    811			rem = aes_mac_update(in, ctx->key_enc, rounds, blocks,
    812					     dg, enc_before, enc_after);
    813			kernel_neon_end();
    814			in += (blocks - rem) * AES_BLOCK_SIZE;
    815			blocks = rem;
    816			enc_before = 0;
    817		} while (blocks);
    818	} else {
    819		if (enc_before)
    820			aes_encrypt(ctx, dg, dg);
    821
    822		while (blocks--) {
    823			crypto_xor(dg, in, AES_BLOCK_SIZE);
    824			in += AES_BLOCK_SIZE;
    825
    826			if (blocks || enc_after)
    827				aes_encrypt(ctx, dg, dg);
    828		}
    829	}
    830}
    831
    832static int mac_update(struct shash_desc *desc, const u8 *p, unsigned int len)
    833{
    834	struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
    835	struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
    836
    837	while (len > 0) {
    838		unsigned int l;
    839
    840		if ((ctx->len % AES_BLOCK_SIZE) == 0 &&
    841		    (ctx->len + len) > AES_BLOCK_SIZE) {
    842
    843			int blocks = len / AES_BLOCK_SIZE;
    844
    845			len %= AES_BLOCK_SIZE;
    846
    847			mac_do_update(&tctx->key, p, blocks, ctx->dg,
    848				      (ctx->len != 0), (len != 0));
    849
    850			p += blocks * AES_BLOCK_SIZE;
    851
    852			if (!len) {
    853				ctx->len = AES_BLOCK_SIZE;
    854				break;
    855			}
    856			ctx->len = 0;
    857		}
    858
    859		l = min(len, AES_BLOCK_SIZE - ctx->len);
    860
    861		if (l <= AES_BLOCK_SIZE) {
    862			crypto_xor(ctx->dg + ctx->len, p, l);
    863			ctx->len += l;
    864			len -= l;
    865			p += l;
    866		}
    867	}
    868
    869	return 0;
    870}
    871
    872static int cbcmac_final(struct shash_desc *desc, u8 *out)
    873{
    874	struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
    875	struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
    876
    877	mac_do_update(&tctx->key, NULL, 0, ctx->dg, (ctx->len != 0), 0);
    878
    879	memcpy(out, ctx->dg, AES_BLOCK_SIZE);
    880
    881	return 0;
    882}
    883
    884static int cmac_final(struct shash_desc *desc, u8 *out)
    885{
    886	struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
    887	struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
    888	u8 *consts = tctx->consts;
    889
    890	if (ctx->len != AES_BLOCK_SIZE) {
    891		ctx->dg[ctx->len] ^= 0x80;
    892		consts += AES_BLOCK_SIZE;
    893	}
    894
    895	mac_do_update(&tctx->key, consts, 1, ctx->dg, 0, 1);
    896
    897	memcpy(out, ctx->dg, AES_BLOCK_SIZE);
    898
    899	return 0;
    900}
    901
    902static struct shash_alg mac_algs[] = { {
    903	.base.cra_name		= "cmac(aes)",
    904	.base.cra_driver_name	= "cmac-aes-" MODE,
    905	.base.cra_priority	= PRIO,
    906	.base.cra_blocksize	= AES_BLOCK_SIZE,
    907	.base.cra_ctxsize	= sizeof(struct mac_tfm_ctx) +
    908				  2 * AES_BLOCK_SIZE,
    909	.base.cra_module	= THIS_MODULE,
    910
    911	.digestsize		= AES_BLOCK_SIZE,
    912	.init			= mac_init,
    913	.update			= mac_update,
    914	.final			= cmac_final,
    915	.setkey			= cmac_setkey,
    916	.descsize		= sizeof(struct mac_desc_ctx),
    917}, {
    918	.base.cra_name		= "xcbc(aes)",
    919	.base.cra_driver_name	= "xcbc-aes-" MODE,
    920	.base.cra_priority	= PRIO,
    921	.base.cra_blocksize	= AES_BLOCK_SIZE,
    922	.base.cra_ctxsize	= sizeof(struct mac_tfm_ctx) +
    923				  2 * AES_BLOCK_SIZE,
    924	.base.cra_module	= THIS_MODULE,
    925
    926	.digestsize		= AES_BLOCK_SIZE,
    927	.init			= mac_init,
    928	.update			= mac_update,
    929	.final			= cmac_final,
    930	.setkey			= xcbc_setkey,
    931	.descsize		= sizeof(struct mac_desc_ctx),
    932}, {
    933	.base.cra_name		= "cbcmac(aes)",
    934	.base.cra_driver_name	= "cbcmac-aes-" MODE,
    935	.base.cra_priority	= PRIO,
    936	.base.cra_blocksize	= 1,
    937	.base.cra_ctxsize	= sizeof(struct mac_tfm_ctx),
    938	.base.cra_module	= THIS_MODULE,
    939
    940	.digestsize		= AES_BLOCK_SIZE,
    941	.init			= mac_init,
    942	.update			= mac_update,
    943	.final			= cbcmac_final,
    944	.setkey			= cbcmac_setkey,
    945	.descsize		= sizeof(struct mac_desc_ctx),
    946} };
    947
/*
 * Module exit: unregister the MAC algorithms, then the skciphers (the
 * reverse of the registration order used in aes_init()).
 */
static void aes_exit(void)
{
	crypto_unregister_shashes(mac_algs, ARRAY_SIZE(mac_algs));
	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}
    953
    954static int __init aes_init(void)
    955{
    956	int err;
    957
    958	err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
    959	if (err)
    960		return err;
    961
    962	err = crypto_register_shashes(mac_algs, ARRAY_SIZE(mac_algs));
    963	if (err)
    964		goto unregister_ciphers;
    965
    966	return 0;
    967
    968unregister_ciphers:
    969	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
    970	return err;
    971}
    972
    973#ifdef USE_V8_CRYPTO_EXTENSIONS
    974module_cpu_feature_match(AES, aes_init);
    975#else
    976module_init(aes_init);
    977EXPORT_SYMBOL(neon_aes_ecb_encrypt);
    978EXPORT_SYMBOL(neon_aes_cbc_encrypt);
    979EXPORT_SYMBOL(neon_aes_ctr_encrypt);
    980EXPORT_SYMBOL(neon_aes_xts_encrypt);
    981EXPORT_SYMBOL(neon_aes_xts_decrypt);
    982#endif
    983module_exit(aes_exit);