cachepc-linux

Fork of AMDESE/linux with modifications for the CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux
Log | Files | Refs | README | LICENSE | sfeed.txt

aes-neonbs-glue.c (12350B)


      1// SPDX-License-Identifier: GPL-2.0-only
      2/*
      3 * Bit sliced AES using NEON instructions
      4 *
      5 * Copyright (C) 2016 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
      6 */
      7
      8#include <asm/neon.h>
      9#include <asm/simd.h>
     10#include <crypto/aes.h>
     11#include <crypto/ctr.h>
     12#include <crypto/internal/simd.h>
     13#include <crypto/internal/skcipher.h>
     14#include <crypto/scatterwalk.h>
     15#include <crypto/xts.h>
     16#include <linux/module.h>
     17
     18MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
     19MODULE_LICENSE("GPL v2");
     20
     21MODULE_ALIAS_CRYPTO("ecb(aes)");
     22MODULE_ALIAS_CRYPTO("cbc(aes)");
     23MODULE_ALIAS_CRYPTO("ctr(aes)");
     24MODULE_ALIAS_CRYPTO("xts(aes)");
     25
     26asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds);
     27
     28asmlinkage void aesbs_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
     29				  int rounds, int blocks);
     30asmlinkage void aesbs_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
     31				  int rounds, int blocks);
     32
     33asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
     34				  int rounds, int blocks, u8 iv[]);
     35
     36asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
     37				  int rounds, int blocks, u8 iv[]);
     38
     39asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
     40				  int rounds, int blocks, u8 iv[]);
     41asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
     42				  int rounds, int blocks, u8 iv[]);
     43
     44/* borrowed from aes-neon-blk.ko */
     45asmlinkage void neon_aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
     46				     int rounds, int blocks);
     47asmlinkage void neon_aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
     48				     int rounds, int blocks, u8 iv[]);
     49asmlinkage void neon_aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
     50				     int rounds, int bytes, u8 ctr[]);
     51asmlinkage void neon_aes_xts_encrypt(u8 out[], u8 const in[],
     52				     u32 const rk1[], int rounds, int bytes,
     53				     u32 const rk2[], u8 iv[], int first);
     54asmlinkage void neon_aes_xts_decrypt(u8 out[], u8 const in[],
     55				     u32 const rk1[], int rounds, int bytes,
     56				     u32 const rk2[], u8 iv[], int first);
     57
     58struct aesbs_ctx {
     59	u8	rk[13 * (8 * AES_BLOCK_SIZE) + 32];
     60	int	rounds;
     61} __aligned(AES_BLOCK_SIZE);
     62
     63struct aesbs_cbc_ctr_ctx {
     64	struct aesbs_ctx	key;
     65	u32			enc[AES_MAX_KEYLENGTH_U32];
     66};
     67
     68struct aesbs_xts_ctx {
     69	struct aesbs_ctx	key;
     70	u32			twkey[AES_MAX_KEYLENGTH_U32];
     71	struct crypto_aes_ctx	cts;
     72};
     73
     74static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
     75			unsigned int key_len)
     76{
     77	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
     78	struct crypto_aes_ctx rk;
     79	int err;
     80
     81	err = aes_expandkey(&rk, in_key, key_len);
     82	if (err)
     83		return err;
     84
     85	ctx->rounds = 6 + key_len / 4;
     86
     87	kernel_neon_begin();
     88	aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds);
     89	kernel_neon_end();
     90
     91	return 0;
     92}
     93
     94static int __ecb_crypt(struct skcipher_request *req,
     95		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
     96				  int rounds, int blocks))
     97{
     98	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
     99	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
    100	struct skcipher_walk walk;
    101	int err;
    102
    103	err = skcipher_walk_virt(&walk, req, false);
    104
    105	while (walk.nbytes >= AES_BLOCK_SIZE) {
    106		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
    107
    108		if (walk.nbytes < walk.total)
    109			blocks = round_down(blocks,
    110					    walk.stride / AES_BLOCK_SIZE);
    111
    112		kernel_neon_begin();
    113		fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
    114		   ctx->rounds, blocks);
    115		kernel_neon_end();
    116		err = skcipher_walk_done(&walk,
    117					 walk.nbytes - blocks * AES_BLOCK_SIZE);
    118	}
    119
    120	return err;
    121}
    122
    123static int ecb_encrypt(struct skcipher_request *req)
    124{
    125	return __ecb_crypt(req, aesbs_ecb_encrypt);
    126}
    127
    128static int ecb_decrypt(struct skcipher_request *req)
    129{
    130	return __ecb_crypt(req, aesbs_ecb_decrypt);
    131}
    132
    133static int aesbs_cbc_ctr_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
    134			    unsigned int key_len)
    135{
    136	struct aesbs_cbc_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
    137	struct crypto_aes_ctx rk;
    138	int err;
    139
    140	err = aes_expandkey(&rk, in_key, key_len);
    141	if (err)
    142		return err;
    143
    144	ctx->key.rounds = 6 + key_len / 4;
    145
    146	memcpy(ctx->enc, rk.key_enc, sizeof(ctx->enc));
    147
    148	kernel_neon_begin();
    149	aesbs_convert_key(ctx->key.rk, rk.key_enc, ctx->key.rounds);
    150	kernel_neon_end();
    151	memzero_explicit(&rk, sizeof(rk));
    152
    153	return 0;
    154}
    155
    156static int cbc_encrypt(struct skcipher_request *req)
    157{
    158	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
    159	struct aesbs_cbc_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
    160	struct skcipher_walk walk;
    161	int err;
    162
    163	err = skcipher_walk_virt(&walk, req, false);
    164
    165	while (walk.nbytes >= AES_BLOCK_SIZE) {
    166		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
    167
    168		/* fall back to the non-bitsliced NEON implementation */
    169		kernel_neon_begin();
    170		neon_aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
    171				     ctx->enc, ctx->key.rounds, blocks,
    172				     walk.iv);
    173		kernel_neon_end();
    174		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
    175	}
    176	return err;
    177}
    178
    179static int cbc_decrypt(struct skcipher_request *req)
    180{
    181	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
    182	struct aesbs_cbc_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
    183	struct skcipher_walk walk;
    184	int err;
    185
    186	err = skcipher_walk_virt(&walk, req, false);
    187
    188	while (walk.nbytes >= AES_BLOCK_SIZE) {
    189		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
    190
    191		if (walk.nbytes < walk.total)
    192			blocks = round_down(blocks,
    193					    walk.stride / AES_BLOCK_SIZE);
    194
    195		kernel_neon_begin();
    196		aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
    197				  ctx->key.rk, ctx->key.rounds, blocks,
    198				  walk.iv);
    199		kernel_neon_end();
    200		err = skcipher_walk_done(&walk,
    201					 walk.nbytes - blocks * AES_BLOCK_SIZE);
    202	}
    203
    204	return err;
    205}
    206
    207static int ctr_encrypt(struct skcipher_request *req)
    208{
    209	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
    210	struct aesbs_cbc_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
    211	struct skcipher_walk walk;
    212	int err;
    213
    214	err = skcipher_walk_virt(&walk, req, false);
    215
    216	while (walk.nbytes > 0) {
    217		int blocks = (walk.nbytes / AES_BLOCK_SIZE) & ~7;
    218		int nbytes = walk.nbytes % (8 * AES_BLOCK_SIZE);
    219		const u8 *src = walk.src.virt.addr;
    220		u8 *dst = walk.dst.virt.addr;
    221
    222		kernel_neon_begin();
    223		if (blocks >= 8) {
    224			aesbs_ctr_encrypt(dst, src, ctx->key.rk, ctx->key.rounds,
    225					  blocks, walk.iv);
    226			dst += blocks * AES_BLOCK_SIZE;
    227			src += blocks * AES_BLOCK_SIZE;
    228		}
    229		if (nbytes && walk.nbytes == walk.total) {
    230			neon_aes_ctr_encrypt(dst, src, ctx->enc, ctx->key.rounds,
    231					     nbytes, walk.iv);
    232			nbytes = 0;
    233		}
    234		kernel_neon_end();
    235		err = skcipher_walk_done(&walk, nbytes);
    236	}
    237	return err;
    238}
    239
    240static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
    241			    unsigned int key_len)
    242{
    243	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
    244	struct crypto_aes_ctx rk;
    245	int err;
    246
    247	err = xts_verify_key(tfm, in_key, key_len);
    248	if (err)
    249		return err;
    250
    251	key_len /= 2;
    252	err = aes_expandkey(&ctx->cts, in_key, key_len);
    253	if (err)
    254		return err;
    255
    256	err = aes_expandkey(&rk, in_key + key_len, key_len);
    257	if (err)
    258		return err;
    259
    260	memcpy(ctx->twkey, rk.key_enc, sizeof(ctx->twkey));
    261
    262	return aesbs_setkey(tfm, in_key, key_len);
    263}
    264
    265static int __xts_crypt(struct skcipher_request *req, bool encrypt,
    266		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
    267				  int rounds, int blocks, u8 iv[]))
    268{
    269	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
    270	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
    271	int tail = req->cryptlen % (8 * AES_BLOCK_SIZE);
    272	struct scatterlist sg_src[2], sg_dst[2];
    273	struct skcipher_request subreq;
    274	struct scatterlist *src, *dst;
    275	struct skcipher_walk walk;
    276	int nbytes, err;
    277	int first = 1;
    278	u8 *out, *in;
    279
    280	if (req->cryptlen < AES_BLOCK_SIZE)
    281		return -EINVAL;
    282
    283	/* ensure that the cts tail is covered by a single step */
    284	if (unlikely(tail > 0 && tail < AES_BLOCK_SIZE)) {
    285		int xts_blocks = DIV_ROUND_UP(req->cryptlen,
    286					      AES_BLOCK_SIZE) - 2;
    287
    288		skcipher_request_set_tfm(&subreq, tfm);
    289		skcipher_request_set_callback(&subreq,
    290					      skcipher_request_flags(req),
    291					      NULL, NULL);
    292		skcipher_request_set_crypt(&subreq, req->src, req->dst,
    293					   xts_blocks * AES_BLOCK_SIZE,
    294					   req->iv);
    295		req = &subreq;
    296	} else {
    297		tail = 0;
    298	}
    299
    300	err = skcipher_walk_virt(&walk, req, false);
    301	if (err)
    302		return err;
    303
    304	while (walk.nbytes >= AES_BLOCK_SIZE) {
    305		int blocks = (walk.nbytes / AES_BLOCK_SIZE) & ~7;
    306		out = walk.dst.virt.addr;
    307		in = walk.src.virt.addr;
    308		nbytes = walk.nbytes;
    309
    310		kernel_neon_begin();
    311		if (blocks >= 8) {
    312			if (first == 1)
    313				neon_aes_ecb_encrypt(walk.iv, walk.iv,
    314						     ctx->twkey,
    315						     ctx->key.rounds, 1);
    316			first = 2;
    317
    318			fn(out, in, ctx->key.rk, ctx->key.rounds, blocks,
    319			   walk.iv);
    320
    321			out += blocks * AES_BLOCK_SIZE;
    322			in += blocks * AES_BLOCK_SIZE;
    323			nbytes -= blocks * AES_BLOCK_SIZE;
    324		}
    325		if (walk.nbytes == walk.total && nbytes > 0) {
    326			if (encrypt)
    327				neon_aes_xts_encrypt(out, in, ctx->cts.key_enc,
    328						     ctx->key.rounds, nbytes,
    329						     ctx->twkey, walk.iv, first);
    330			else
    331				neon_aes_xts_decrypt(out, in, ctx->cts.key_dec,
    332						     ctx->key.rounds, nbytes,
    333						     ctx->twkey, walk.iv, first);
    334			nbytes = first = 0;
    335		}
    336		kernel_neon_end();
    337		err = skcipher_walk_done(&walk, nbytes);
    338	}
    339
    340	if (err || likely(!tail))
    341		return err;
    342
    343	/* handle ciphertext stealing */
    344	dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
    345	if (req->dst != req->src)
    346		dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);
    347
    348	skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
    349				   req->iv);
    350
    351	err = skcipher_walk_virt(&walk, req, false);
    352	if (err)
    353		return err;
    354
    355	out = walk.dst.virt.addr;
    356	in = walk.src.virt.addr;
    357	nbytes = walk.nbytes;
    358
    359	kernel_neon_begin();
    360	if (encrypt)
    361		neon_aes_xts_encrypt(out, in, ctx->cts.key_enc, ctx->key.rounds,
    362				     nbytes, ctx->twkey, walk.iv, first);
    363	else
    364		neon_aes_xts_decrypt(out, in, ctx->cts.key_dec, ctx->key.rounds,
    365				     nbytes, ctx->twkey, walk.iv, first);
    366	kernel_neon_end();
    367
    368	return skcipher_walk_done(&walk, 0);
    369}
    370
    371static int xts_encrypt(struct skcipher_request *req)
    372{
    373	return __xts_crypt(req, true, aesbs_xts_encrypt);
    374}
    375
    376static int xts_decrypt(struct skcipher_request *req)
    377{
    378	return __xts_crypt(req, false, aesbs_xts_decrypt);
    379}
    380
    381static struct skcipher_alg aes_algs[] = { {
    382	.base.cra_name		= "ecb(aes)",
    383	.base.cra_driver_name	= "ecb-aes-neonbs",
    384	.base.cra_priority	= 250,
    385	.base.cra_blocksize	= AES_BLOCK_SIZE,
    386	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
    387	.base.cra_module	= THIS_MODULE,
    388
    389	.min_keysize		= AES_MIN_KEY_SIZE,
    390	.max_keysize		= AES_MAX_KEY_SIZE,
    391	.walksize		= 8 * AES_BLOCK_SIZE,
    392	.setkey			= aesbs_setkey,
    393	.encrypt		= ecb_encrypt,
    394	.decrypt		= ecb_decrypt,
    395}, {
    396	.base.cra_name		= "cbc(aes)",
    397	.base.cra_driver_name	= "cbc-aes-neonbs",
    398	.base.cra_priority	= 250,
    399	.base.cra_blocksize	= AES_BLOCK_SIZE,
    400	.base.cra_ctxsize	= sizeof(struct aesbs_cbc_ctr_ctx),
    401	.base.cra_module	= THIS_MODULE,
    402
    403	.min_keysize		= AES_MIN_KEY_SIZE,
    404	.max_keysize		= AES_MAX_KEY_SIZE,
    405	.walksize		= 8 * AES_BLOCK_SIZE,
    406	.ivsize			= AES_BLOCK_SIZE,
    407	.setkey			= aesbs_cbc_ctr_setkey,
    408	.encrypt		= cbc_encrypt,
    409	.decrypt		= cbc_decrypt,
    410}, {
    411	.base.cra_name		= "ctr(aes)",
    412	.base.cra_driver_name	= "ctr-aes-neonbs",
    413	.base.cra_priority	= 250,
    414	.base.cra_blocksize	= 1,
    415	.base.cra_ctxsize	= sizeof(struct aesbs_cbc_ctr_ctx),
    416	.base.cra_module	= THIS_MODULE,
    417
    418	.min_keysize		= AES_MIN_KEY_SIZE,
    419	.max_keysize		= AES_MAX_KEY_SIZE,
    420	.chunksize		= AES_BLOCK_SIZE,
    421	.walksize		= 8 * AES_BLOCK_SIZE,
    422	.ivsize			= AES_BLOCK_SIZE,
    423	.setkey			= aesbs_cbc_ctr_setkey,
    424	.encrypt		= ctr_encrypt,
    425	.decrypt		= ctr_encrypt,
    426}, {
    427	.base.cra_name		= "xts(aes)",
    428	.base.cra_driver_name	= "xts-aes-neonbs",
    429	.base.cra_priority	= 250,
    430	.base.cra_blocksize	= AES_BLOCK_SIZE,
    431	.base.cra_ctxsize	= sizeof(struct aesbs_xts_ctx),
    432	.base.cra_module	= THIS_MODULE,
    433
    434	.min_keysize		= 2 * AES_MIN_KEY_SIZE,
    435	.max_keysize		= 2 * AES_MAX_KEY_SIZE,
    436	.walksize		= 8 * AES_BLOCK_SIZE,
    437	.ivsize			= AES_BLOCK_SIZE,
    438	.setkey			= aesbs_xts_setkey,
    439	.encrypt		= xts_encrypt,
    440	.decrypt		= xts_decrypt,
    441} };
    442
    443static void aes_exit(void)
    444{
    445	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
    446}
    447
    448static int __init aes_init(void)
    449{
    450	if (!cpu_have_named_feature(ASIMD))
    451		return -ENODEV;
    452
    453	return crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
    454}
    455
    456module_init(aes_init);
    457module_exit(aes_exit);