cachepc-linux

Fork of AMDESE/linux with modifications for CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux
Log | Files | Refs | README | LICENSE | sfeed.txt

aes-ce-glue.c (20023B)


      1// SPDX-License-Identifier: GPL-2.0-only
      2/*
      3 * aes-ce-glue.c - wrapper code for ARMv8 AES
      4 *
      5 * Copyright (C) 2015 Linaro Ltd <ard.biesheuvel@linaro.org>
      6 */
      7
      8#include <asm/hwcap.h>
      9#include <asm/neon.h>
     10#include <asm/simd.h>
     11#include <asm/unaligned.h>
     12#include <crypto/aes.h>
     13#include <crypto/ctr.h>
     14#include <crypto/internal/simd.h>
     15#include <crypto/internal/skcipher.h>
     16#include <crypto/scatterwalk.h>
     17#include <linux/cpufeature.h>
     18#include <linux/module.h>
     19#include <crypto/xts.h>
     20
     21MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 Crypto Extensions");
     22MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
     23MODULE_LICENSE("GPL v2");
     24
     25/* defined in aes-ce-core.S */
     26asmlinkage u32 ce_aes_sub(u32 input);
     27asmlinkage void ce_aes_invert(void *dst, void *src);
     28
     29asmlinkage void ce_aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
     30				   int rounds, int blocks);
     31asmlinkage void ce_aes_ecb_decrypt(u8 out[], u8 const in[], u32 const rk[],
     32				   int rounds, int blocks);
     33
     34asmlinkage void ce_aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
     35				   int rounds, int blocks, u8 iv[]);
     36asmlinkage void ce_aes_cbc_decrypt(u8 out[], u8 const in[], u32 const rk[],
     37				   int rounds, int blocks, u8 iv[]);
     38asmlinkage void ce_aes_cbc_cts_encrypt(u8 out[], u8 const in[], u32 const rk[],
     39				   int rounds, int bytes, u8 const iv[]);
     40asmlinkage void ce_aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[],
     41				   int rounds, int bytes, u8 const iv[]);
     42
     43asmlinkage void ce_aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
     44				   int rounds, int blocks, u8 ctr[]);
     45
     46asmlinkage void ce_aes_xts_encrypt(u8 out[], u8 const in[], u32 const rk1[],
     47				   int rounds, int bytes, u8 iv[],
     48				   u32 const rk2[], int first);
     49asmlinkage void ce_aes_xts_decrypt(u8 out[], u8 const in[], u32 const rk1[],
     50				   int rounds, int bytes, u8 iv[],
     51				   u32 const rk2[], int first);
     52
     53struct aes_block {
     54	u8 b[AES_BLOCK_SIZE];
     55};
     56
     57static int num_rounds(struct crypto_aes_ctx *ctx)
     58{
     59	/*
     60	 * # of rounds specified by AES:
     61	 * 128 bit key		10 rounds
     62	 * 192 bit key		12 rounds
     63	 * 256 bit key		14 rounds
     64	 * => n byte key	=> 6 + (n/4) rounds
     65	 */
     66	return 6 + ctx->key_length / 4;
     67}
     68
     69static int ce_aes_expandkey(struct crypto_aes_ctx *ctx, const u8 *in_key,
     70			    unsigned int key_len)
     71{
     72	/*
     73	 * The AES key schedule round constants
     74	 */
     75	static u8 const rcon[] = {
     76		0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36,
     77	};
     78
     79	u32 kwords = key_len / sizeof(u32);
     80	struct aes_block *key_enc, *key_dec;
     81	int i, j;
     82
     83	if (key_len != AES_KEYSIZE_128 &&
     84	    key_len != AES_KEYSIZE_192 &&
     85	    key_len != AES_KEYSIZE_256)
     86		return -EINVAL;
     87
     88	ctx->key_length = key_len;
     89	for (i = 0; i < kwords; i++)
     90		ctx->key_enc[i] = get_unaligned_le32(in_key + i * sizeof(u32));
     91
     92	kernel_neon_begin();
     93	for (i = 0; i < sizeof(rcon); i++) {
     94		u32 *rki = ctx->key_enc + (i * kwords);
     95		u32 *rko = rki + kwords;
     96
     97		rko[0] = ror32(ce_aes_sub(rki[kwords - 1]), 8);
     98		rko[0] = rko[0] ^ rki[0] ^ rcon[i];
     99		rko[1] = rko[0] ^ rki[1];
    100		rko[2] = rko[1] ^ rki[2];
    101		rko[3] = rko[2] ^ rki[3];
    102
    103		if (key_len == AES_KEYSIZE_192) {
    104			if (i >= 7)
    105				break;
    106			rko[4] = rko[3] ^ rki[4];
    107			rko[5] = rko[4] ^ rki[5];
    108		} else if (key_len == AES_KEYSIZE_256) {
    109			if (i >= 6)
    110				break;
    111			rko[4] = ce_aes_sub(rko[3]) ^ rki[4];
    112			rko[5] = rko[4] ^ rki[5];
    113			rko[6] = rko[5] ^ rki[6];
    114			rko[7] = rko[6] ^ rki[7];
    115		}
    116	}
    117
    118	/*
    119	 * Generate the decryption keys for the Equivalent Inverse Cipher.
    120	 * This involves reversing the order of the round keys, and applying
    121	 * the Inverse Mix Columns transformation on all but the first and
    122	 * the last one.
    123	 */
    124	key_enc = (struct aes_block *)ctx->key_enc;
    125	key_dec = (struct aes_block *)ctx->key_dec;
    126	j = num_rounds(ctx);
    127
    128	key_dec[0] = key_enc[j];
    129	for (i = 1, j--; j > 0; i++, j--)
    130		ce_aes_invert(key_dec + i, key_enc + j);
    131	key_dec[i] = key_enc[0];
    132
    133	kernel_neon_end();
    134	return 0;
    135}
    136
    137static int ce_aes_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
    138			 unsigned int key_len)
    139{
    140	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
    141
    142	return ce_aes_expandkey(ctx, in_key, key_len);
    143}
    144
    145struct crypto_aes_xts_ctx {
    146	struct crypto_aes_ctx key1;
    147	struct crypto_aes_ctx __aligned(8) key2;
    148};
    149
    150static int xts_set_key(struct crypto_skcipher *tfm, const u8 *in_key,
    151		       unsigned int key_len)
    152{
    153	struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
    154	int ret;
    155
    156	ret = xts_verify_key(tfm, in_key, key_len);
    157	if (ret)
    158		return ret;
    159
    160	ret = ce_aes_expandkey(&ctx->key1, in_key, key_len / 2);
    161	if (!ret)
    162		ret = ce_aes_expandkey(&ctx->key2, &in_key[key_len / 2],
    163				       key_len / 2);
    164	return ret;
    165}
    166
    167static int ecb_encrypt(struct skcipher_request *req)
    168{
    169	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
    170	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
    171	struct skcipher_walk walk;
    172	unsigned int blocks;
    173	int err;
    174
    175	err = skcipher_walk_virt(&walk, req, false);
    176
    177	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
    178		kernel_neon_begin();
    179		ce_aes_ecb_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
    180				   ctx->key_enc, num_rounds(ctx), blocks);
    181		kernel_neon_end();
    182		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
    183	}
    184	return err;
    185}
    186
    187static int ecb_decrypt(struct skcipher_request *req)
    188{
    189	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
    190	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
    191	struct skcipher_walk walk;
    192	unsigned int blocks;
    193	int err;
    194
    195	err = skcipher_walk_virt(&walk, req, false);
    196
    197	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
    198		kernel_neon_begin();
    199		ce_aes_ecb_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
    200				   ctx->key_dec, num_rounds(ctx), blocks);
    201		kernel_neon_end();
    202		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
    203	}
    204	return err;
    205}
    206
    207static int cbc_encrypt_walk(struct skcipher_request *req,
    208			    struct skcipher_walk *walk)
    209{
    210	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
    211	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
    212	unsigned int blocks;
    213	int err = 0;
    214
    215	while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
    216		kernel_neon_begin();
    217		ce_aes_cbc_encrypt(walk->dst.virt.addr, walk->src.virt.addr,
    218				   ctx->key_enc, num_rounds(ctx), blocks,
    219				   walk->iv);
    220		kernel_neon_end();
    221		err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
    222	}
    223	return err;
    224}
    225
    226static int cbc_encrypt(struct skcipher_request *req)
    227{
    228	struct skcipher_walk walk;
    229	int err;
    230
    231	err = skcipher_walk_virt(&walk, req, false);
    232	if (err)
    233		return err;
    234	return cbc_encrypt_walk(req, &walk);
    235}
    236
    237static int cbc_decrypt_walk(struct skcipher_request *req,
    238			    struct skcipher_walk *walk)
    239{
    240	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
    241	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
    242	unsigned int blocks;
    243	int err = 0;
    244
    245	while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
    246		kernel_neon_begin();
    247		ce_aes_cbc_decrypt(walk->dst.virt.addr, walk->src.virt.addr,
    248				   ctx->key_dec, num_rounds(ctx), blocks,
    249				   walk->iv);
    250		kernel_neon_end();
    251		err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
    252	}
    253	return err;
    254}
    255
    256static int cbc_decrypt(struct skcipher_request *req)
    257{
    258	struct skcipher_walk walk;
    259	int err;
    260
    261	err = skcipher_walk_virt(&walk, req, false);
    262	if (err)
    263		return err;
    264	return cbc_decrypt_walk(req, &walk);
    265}
    266
    267static int cts_cbc_encrypt(struct skcipher_request *req)
    268{
    269	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
    270	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
    271	int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
    272	struct scatterlist *src = req->src, *dst = req->dst;
    273	struct scatterlist sg_src[2], sg_dst[2];
    274	struct skcipher_request subreq;
    275	struct skcipher_walk walk;
    276	int err;
    277
    278	skcipher_request_set_tfm(&subreq, tfm);
    279	skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
    280				      NULL, NULL);
    281
    282	if (req->cryptlen <= AES_BLOCK_SIZE) {
    283		if (req->cryptlen < AES_BLOCK_SIZE)
    284			return -EINVAL;
    285		cbc_blocks = 1;
    286	}
    287
    288	if (cbc_blocks > 0) {
    289		skcipher_request_set_crypt(&subreq, req->src, req->dst,
    290					   cbc_blocks * AES_BLOCK_SIZE,
    291					   req->iv);
    292
    293		err = skcipher_walk_virt(&walk, &subreq, false) ?:
    294		      cbc_encrypt_walk(&subreq, &walk);
    295		if (err)
    296			return err;
    297
    298		if (req->cryptlen == AES_BLOCK_SIZE)
    299			return 0;
    300
    301		dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
    302		if (req->dst != req->src)
    303			dst = scatterwalk_ffwd(sg_dst, req->dst,
    304					       subreq.cryptlen);
    305	}
    306
    307	/* handle ciphertext stealing */
    308	skcipher_request_set_crypt(&subreq, src, dst,
    309				   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
    310				   req->iv);
    311
    312	err = skcipher_walk_virt(&walk, &subreq, false);
    313	if (err)
    314		return err;
    315
    316	kernel_neon_begin();
    317	ce_aes_cbc_cts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
    318			       ctx->key_enc, num_rounds(ctx), walk.nbytes,
    319			       walk.iv);
    320	kernel_neon_end();
    321
    322	return skcipher_walk_done(&walk, 0);
    323}
    324
    325static int cts_cbc_decrypt(struct skcipher_request *req)
    326{
    327	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
    328	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
    329	int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
    330	struct scatterlist *src = req->src, *dst = req->dst;
    331	struct scatterlist sg_src[2], sg_dst[2];
    332	struct skcipher_request subreq;
    333	struct skcipher_walk walk;
    334	int err;
    335
    336	skcipher_request_set_tfm(&subreq, tfm);
    337	skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
    338				      NULL, NULL);
    339
    340	if (req->cryptlen <= AES_BLOCK_SIZE) {
    341		if (req->cryptlen < AES_BLOCK_SIZE)
    342			return -EINVAL;
    343		cbc_blocks = 1;
    344	}
    345
    346	if (cbc_blocks > 0) {
    347		skcipher_request_set_crypt(&subreq, req->src, req->dst,
    348					   cbc_blocks * AES_BLOCK_SIZE,
    349					   req->iv);
    350
    351		err = skcipher_walk_virt(&walk, &subreq, false) ?:
    352		      cbc_decrypt_walk(&subreq, &walk);
    353		if (err)
    354			return err;
    355
    356		if (req->cryptlen == AES_BLOCK_SIZE)
    357			return 0;
    358
    359		dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
    360		if (req->dst != req->src)
    361			dst = scatterwalk_ffwd(sg_dst, req->dst,
    362					       subreq.cryptlen);
    363	}
    364
    365	/* handle ciphertext stealing */
    366	skcipher_request_set_crypt(&subreq, src, dst,
    367				   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
    368				   req->iv);
    369
    370	err = skcipher_walk_virt(&walk, &subreq, false);
    371	if (err)
    372		return err;
    373
    374	kernel_neon_begin();
    375	ce_aes_cbc_cts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
    376			       ctx->key_dec, num_rounds(ctx), walk.nbytes,
    377			       walk.iv);
    378	kernel_neon_end();
    379
    380	return skcipher_walk_done(&walk, 0);
    381}
    382
    383static int ctr_encrypt(struct skcipher_request *req)
    384{
    385	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
    386	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
    387	struct skcipher_walk walk;
    388	int err, blocks;
    389
    390	err = skcipher_walk_virt(&walk, req, false);
    391
    392	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
    393		kernel_neon_begin();
    394		ce_aes_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
    395				   ctx->key_enc, num_rounds(ctx), blocks,
    396				   walk.iv);
    397		kernel_neon_end();
    398		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
    399	}
    400	if (walk.nbytes) {
    401		u8 __aligned(8) tail[AES_BLOCK_SIZE];
    402		unsigned int nbytes = walk.nbytes;
    403		u8 *tdst = walk.dst.virt.addr;
    404		u8 *tsrc = walk.src.virt.addr;
    405
    406		/*
    407		 * Tell aes_ctr_encrypt() to process a tail block.
    408		 */
    409		blocks = -1;
    410
    411		kernel_neon_begin();
    412		ce_aes_ctr_encrypt(tail, NULL, ctx->key_enc, num_rounds(ctx),
    413				   blocks, walk.iv);
    414		kernel_neon_end();
    415		crypto_xor_cpy(tdst, tsrc, tail, nbytes);
    416		err = skcipher_walk_done(&walk, 0);
    417	}
    418	return err;
    419}
    420
    421static void ctr_encrypt_one(struct crypto_skcipher *tfm, const u8 *src, u8 *dst)
    422{
    423	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
    424	unsigned long flags;
    425
    426	/*
    427	 * Temporarily disable interrupts to avoid races where
    428	 * cachelines are evicted when the CPU is interrupted
    429	 * to do something else.
    430	 */
    431	local_irq_save(flags);
    432	aes_encrypt(ctx, dst, src);
    433	local_irq_restore(flags);
    434}
    435
    436static int ctr_encrypt_sync(struct skcipher_request *req)
    437{
    438	if (!crypto_simd_usable())
    439		return crypto_ctr_encrypt_walk(req, ctr_encrypt_one);
    440
    441	return ctr_encrypt(req);
    442}
    443
    444static int xts_encrypt(struct skcipher_request *req)
    445{
    446	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
    447	struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
    448	int err, first, rounds = num_rounds(&ctx->key1);
    449	int tail = req->cryptlen % AES_BLOCK_SIZE;
    450	struct scatterlist sg_src[2], sg_dst[2];
    451	struct skcipher_request subreq;
    452	struct scatterlist *src, *dst;
    453	struct skcipher_walk walk;
    454
    455	if (req->cryptlen < AES_BLOCK_SIZE)
    456		return -EINVAL;
    457
    458	err = skcipher_walk_virt(&walk, req, false);
    459
    460	if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
    461		int xts_blocks = DIV_ROUND_UP(req->cryptlen,
    462					      AES_BLOCK_SIZE) - 2;
    463
    464		skcipher_walk_abort(&walk);
    465
    466		skcipher_request_set_tfm(&subreq, tfm);
    467		skcipher_request_set_callback(&subreq,
    468					      skcipher_request_flags(req),
    469					      NULL, NULL);
    470		skcipher_request_set_crypt(&subreq, req->src, req->dst,
    471					   xts_blocks * AES_BLOCK_SIZE,
    472					   req->iv);
    473		req = &subreq;
    474		err = skcipher_walk_virt(&walk, req, false);
    475	} else {
    476		tail = 0;
    477	}
    478
    479	for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
    480		int nbytes = walk.nbytes;
    481
    482		if (walk.nbytes < walk.total)
    483			nbytes &= ~(AES_BLOCK_SIZE - 1);
    484
    485		kernel_neon_begin();
    486		ce_aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
    487				   ctx->key1.key_enc, rounds, nbytes, walk.iv,
    488				   ctx->key2.key_enc, first);
    489		kernel_neon_end();
    490		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
    491	}
    492
    493	if (err || likely(!tail))
    494		return err;
    495
    496	dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
    497	if (req->dst != req->src)
    498		dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);
    499
    500	skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
    501				   req->iv);
    502
    503	err = skcipher_walk_virt(&walk, req, false);
    504	if (err)
    505		return err;
    506
    507	kernel_neon_begin();
    508	ce_aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
    509			   ctx->key1.key_enc, rounds, walk.nbytes, walk.iv,
    510			   ctx->key2.key_enc, first);
    511	kernel_neon_end();
    512
    513	return skcipher_walk_done(&walk, 0);
    514}
    515
    516static int xts_decrypt(struct skcipher_request *req)
    517{
    518	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
    519	struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
    520	int err, first, rounds = num_rounds(&ctx->key1);
    521	int tail = req->cryptlen % AES_BLOCK_SIZE;
    522	struct scatterlist sg_src[2], sg_dst[2];
    523	struct skcipher_request subreq;
    524	struct scatterlist *src, *dst;
    525	struct skcipher_walk walk;
    526
    527	if (req->cryptlen < AES_BLOCK_SIZE)
    528		return -EINVAL;
    529
    530	err = skcipher_walk_virt(&walk, req, false);
    531
    532	if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
    533		int xts_blocks = DIV_ROUND_UP(req->cryptlen,
    534					      AES_BLOCK_SIZE) - 2;
    535
    536		skcipher_walk_abort(&walk);
    537
    538		skcipher_request_set_tfm(&subreq, tfm);
    539		skcipher_request_set_callback(&subreq,
    540					      skcipher_request_flags(req),
    541					      NULL, NULL);
    542		skcipher_request_set_crypt(&subreq, req->src, req->dst,
    543					   xts_blocks * AES_BLOCK_SIZE,
    544					   req->iv);
    545		req = &subreq;
    546		err = skcipher_walk_virt(&walk, req, false);
    547	} else {
    548		tail = 0;
    549	}
    550
    551	for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
    552		int nbytes = walk.nbytes;
    553
    554		if (walk.nbytes < walk.total)
    555			nbytes &= ~(AES_BLOCK_SIZE - 1);
    556
    557		kernel_neon_begin();
    558		ce_aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
    559				   ctx->key1.key_dec, rounds, nbytes, walk.iv,
    560				   ctx->key2.key_enc, first);
    561		kernel_neon_end();
    562		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
    563	}
    564
    565	if (err || likely(!tail))
    566		return err;
    567
    568	dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
    569	if (req->dst != req->src)
    570		dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);
    571
    572	skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
    573				   req->iv);
    574
    575	err = skcipher_walk_virt(&walk, req, false);
    576	if (err)
    577		return err;
    578
    579	kernel_neon_begin();
    580	ce_aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
    581			   ctx->key1.key_dec, rounds, walk.nbytes, walk.iv,
    582			   ctx->key2.key_enc, first);
    583	kernel_neon_end();
    584
    585	return skcipher_walk_done(&walk, 0);
    586}
    587
    588static struct skcipher_alg aes_algs[] = { {
    589	.base.cra_name		= "__ecb(aes)",
    590	.base.cra_driver_name	= "__ecb-aes-ce",
    591	.base.cra_priority	= 300,
    592	.base.cra_flags		= CRYPTO_ALG_INTERNAL,
    593	.base.cra_blocksize	= AES_BLOCK_SIZE,
    594	.base.cra_ctxsize	= sizeof(struct crypto_aes_ctx),
    595	.base.cra_module	= THIS_MODULE,
    596
    597	.min_keysize		= AES_MIN_KEY_SIZE,
    598	.max_keysize		= AES_MAX_KEY_SIZE,
    599	.setkey			= ce_aes_setkey,
    600	.encrypt		= ecb_encrypt,
    601	.decrypt		= ecb_decrypt,
    602}, {
    603	.base.cra_name		= "__cbc(aes)",
    604	.base.cra_driver_name	= "__cbc-aes-ce",
    605	.base.cra_priority	= 300,
    606	.base.cra_flags		= CRYPTO_ALG_INTERNAL,
    607	.base.cra_blocksize	= AES_BLOCK_SIZE,
    608	.base.cra_ctxsize	= sizeof(struct crypto_aes_ctx),
    609	.base.cra_module	= THIS_MODULE,
    610
    611	.min_keysize		= AES_MIN_KEY_SIZE,
    612	.max_keysize		= AES_MAX_KEY_SIZE,
    613	.ivsize			= AES_BLOCK_SIZE,
    614	.setkey			= ce_aes_setkey,
    615	.encrypt		= cbc_encrypt,
    616	.decrypt		= cbc_decrypt,
    617}, {
    618	.base.cra_name		= "__cts(cbc(aes))",
    619	.base.cra_driver_name	= "__cts-cbc-aes-ce",
    620	.base.cra_priority	= 300,
    621	.base.cra_flags		= CRYPTO_ALG_INTERNAL,
    622	.base.cra_blocksize	= AES_BLOCK_SIZE,
    623	.base.cra_ctxsize	= sizeof(struct crypto_aes_ctx),
    624	.base.cra_module	= THIS_MODULE,
    625
    626	.min_keysize		= AES_MIN_KEY_SIZE,
    627	.max_keysize		= AES_MAX_KEY_SIZE,
    628	.ivsize			= AES_BLOCK_SIZE,
    629	.walksize		= 2 * AES_BLOCK_SIZE,
    630	.setkey			= ce_aes_setkey,
    631	.encrypt		= cts_cbc_encrypt,
    632	.decrypt		= cts_cbc_decrypt,
    633}, {
    634	.base.cra_name		= "__ctr(aes)",
    635	.base.cra_driver_name	= "__ctr-aes-ce",
    636	.base.cra_priority	= 300,
    637	.base.cra_flags		= CRYPTO_ALG_INTERNAL,
    638	.base.cra_blocksize	= 1,
    639	.base.cra_ctxsize	= sizeof(struct crypto_aes_ctx),
    640	.base.cra_module	= THIS_MODULE,
    641
    642	.min_keysize		= AES_MIN_KEY_SIZE,
    643	.max_keysize		= AES_MAX_KEY_SIZE,
    644	.ivsize			= AES_BLOCK_SIZE,
    645	.chunksize		= AES_BLOCK_SIZE,
    646	.setkey			= ce_aes_setkey,
    647	.encrypt		= ctr_encrypt,
    648	.decrypt		= ctr_encrypt,
    649}, {
    650	.base.cra_name		= "ctr(aes)",
    651	.base.cra_driver_name	= "ctr-aes-ce-sync",
    652	.base.cra_priority	= 300 - 1,
    653	.base.cra_blocksize	= 1,
    654	.base.cra_ctxsize	= sizeof(struct crypto_aes_ctx),
    655	.base.cra_module	= THIS_MODULE,
    656
    657	.min_keysize		= AES_MIN_KEY_SIZE,
    658	.max_keysize		= AES_MAX_KEY_SIZE,
    659	.ivsize			= AES_BLOCK_SIZE,
    660	.chunksize		= AES_BLOCK_SIZE,
    661	.setkey			= ce_aes_setkey,
    662	.encrypt		= ctr_encrypt_sync,
    663	.decrypt		= ctr_encrypt_sync,
    664}, {
    665	.base.cra_name		= "__xts(aes)",
    666	.base.cra_driver_name	= "__xts-aes-ce",
    667	.base.cra_priority	= 300,
    668	.base.cra_flags		= CRYPTO_ALG_INTERNAL,
    669	.base.cra_blocksize	= AES_BLOCK_SIZE,
    670	.base.cra_ctxsize	= sizeof(struct crypto_aes_xts_ctx),
    671	.base.cra_module	= THIS_MODULE,
    672
    673	.min_keysize		= 2 * AES_MIN_KEY_SIZE,
    674	.max_keysize		= 2 * AES_MAX_KEY_SIZE,
    675	.ivsize			= AES_BLOCK_SIZE,
    676	.walksize		= 2 * AES_BLOCK_SIZE,
    677	.setkey			= xts_set_key,
    678	.encrypt		= xts_encrypt,
    679	.decrypt		= xts_decrypt,
    680} };
    681
    682static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];
    683
    684static void aes_exit(void)
    685{
    686	int i;
    687
    688	for (i = 0; i < ARRAY_SIZE(aes_simd_algs) && aes_simd_algs[i]; i++)
    689		simd_skcipher_free(aes_simd_algs[i]);
    690
    691	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
    692}
    693
    694static int __init aes_init(void)
    695{
    696	struct simd_skcipher_alg *simd;
    697	const char *basename;
    698	const char *algname;
    699	const char *drvname;
    700	int err;
    701	int i;
    702
    703	err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
    704	if (err)
    705		return err;
    706
    707	for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
    708		if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
    709			continue;
    710
    711		algname = aes_algs[i].base.cra_name + 2;
    712		drvname = aes_algs[i].base.cra_driver_name + 2;
    713		basename = aes_algs[i].base.cra_driver_name;
    714		simd = simd_skcipher_create_compat(algname, drvname, basename);
    715		err = PTR_ERR(simd);
    716		if (IS_ERR(simd))
    717			goto unregister_simds;
    718
    719		aes_simd_algs[i] = simd;
    720	}
    721
    722	return 0;
    723
    724unregister_simds:
    725	aes_exit();
    726	return err;
    727}
    728
    729module_cpu_feature_match(AES, aes_init);
    730module_exit(aes_exit);