cachepc-linux

Fork of AMDESE/linux with modifications for the CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux

aes-neonbs-glue.c (14517B)


// SPDX-License-Identifier: GPL-2.0-only
/*
 * Bit sliced AES using NEON instructions
 *
 * Copyright (C) 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
 */

#include <asm/neon.h>
#include <asm/simd.h>
#include <crypto/aes.h>
#include <crypto/ctr.h>
#include <crypto/internal/cipher.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/scatterwalk.h>
#include <crypto/xts.h>
#include <linux/module.h>

MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_LICENSE("GPL v2");

MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)-all");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");

MODULE_IMPORT_NS(CRYPTO_INTERNAL);

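/*
 * Prototypes for the bit-sliced AES routines implemented in NEON assembly.
 * aesbs_convert_key() turns an expanded AES key into the bit-sliced
 * representation consumed by the other routines.
 */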
asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds);

asmlinkage void aesbs_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);
asmlinkage void aesbs_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);

asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);

asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 ctr[]);

asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[], int);
asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[], int);

struct aesbs_ctx {
	int	rounds;
	u8	rk[13 * (8 * AES_BLOCK_SIZE) + 32] __aligned(AES_BLOCK_SIZE);
};

struct aesbs_cbc_ctx {
	struct aesbs_ctx	key;
	struct crypto_skcipher	*enc_tfm;
};

struct aesbs_xts_ctx {
	struct aesbs_ctx	key;
	struct crypto_cipher	*cts_tfm;
	struct crypto_cipher	*tweak_tfm;
};

struct aesbs_ctr_ctx {
	struct aesbs_ctx	key;		/* must be first member */
	struct crypto_aes_ctx	fallback;
};

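/*
 * Expand the user-supplied key and convert the encryption round keys into
 * the bit-sliced layout expected by the NEON code.
 */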
static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			unsigned int key_len)
{
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = aes_expandkey(&rk, in_key, key_len);
	if (err)
		return err;

	ctx->rounds = 6 + key_len / 4;

	kernel_neon_begin();
	aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds);
	kernel_neon_end();

	return 0;
}

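/*
 * Common ECB helper: walk the request and hand full blocks to the NEON
 * routine, rounding the block count down to a multiple of the walk stride
 * for every chunk except the last one.
 */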
static int __ecb_crypt(struct skcipher_request *req,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		kernel_neon_begin();
		fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
		   ctx->rounds, blocks);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}

static int ecb_encrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_encrypt);
}

static int ecb_decrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_decrypt);
}

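/*
 * CBC setkey: convert the key for the NEON decryption path and also key
 * the fallback cbc(aes) transform that handles encryption.
 */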
static int aesbs_cbc_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			    unsigned int key_len)
{
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = aes_expandkey(&rk, in_key, key_len);
	if (err)
		return err;

	ctx->key.rounds = 6 + key_len / 4;

	kernel_neon_begin();
	aesbs_convert_key(ctx->key.rk, rk.key_enc, ctx->key.rounds);
	kernel_neon_end();
	memzero_explicit(&rk, sizeof(rk));

	return crypto_skcipher_setkey(ctx->enc_tfm, in_key, key_len);
}

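/*
 * CBC encryption is strictly sequential, so it gains nothing from the
 * multi-block bit-sliced code; delegate it to the fallback cbc(aes)
 * skcipher allocated in cbc_init().
 */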
static int cbc_encrypt(struct skcipher_request *req)
{
	struct skcipher_request *subreq = skcipher_request_ctx(req);
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);

	skcipher_request_set_tfm(subreq, ctx->enc_tfm);
	skcipher_request_set_callback(subreq,
				      skcipher_request_flags(req),
				      NULL, NULL);
	skcipher_request_set_crypt(subreq, req->src, req->dst,
				   req->cryptlen, req->iv);

	return crypto_skcipher_encrypt(subreq);
}

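/*
 * CBC decryption of consecutive blocks is independent, so it can be fed to
 * the bit-sliced NEON routine several blocks at a time.
 */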
static int cbc_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		kernel_neon_begin();
		aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
				  ctx->key.rk, ctx->key.rounds, blocks,
				  walk.iv);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}

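/*
 * Allocate the fallback cbc(aes) transform and size the request context so
 * that a subrequest for it fits inside.
 */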
static int cbc_init(struct crypto_skcipher *tfm)
{
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	unsigned int reqsize;

	ctx->enc_tfm = crypto_alloc_skcipher("cbc(aes)", 0, CRYPTO_ALG_ASYNC |
					     CRYPTO_ALG_NEED_FALLBACK);
	if (IS_ERR(ctx->enc_tfm))
		return PTR_ERR(ctx->enc_tfm);

	reqsize = sizeof(struct skcipher_request);
	reqsize += crypto_skcipher_reqsize(ctx->enc_tfm);
	crypto_skcipher_set_reqsize(tfm, reqsize);

	return 0;
}

static void cbc_exit(struct crypto_skcipher *tfm)
{
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);

	crypto_free_skcipher(ctx->enc_tfm);
}

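/*
 * The sync CTR variant keeps the expanded key around as well, so that the
 * scalar library implementation can be used via ctr_encrypt_one() when the
 * NEON unit is unavailable.
 */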
static int aesbs_ctr_setkey_sync(struct crypto_skcipher *tfm, const u8 *in_key,
				 unsigned int key_len)
{
	struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err;

	err = aes_expandkey(&ctx->fallback, in_key, key_len);
	if (err)
		return err;

	ctx->key.rounds = 6 + key_len / 4;

	kernel_neon_begin();
	aesbs_convert_key(ctx->key.rk, ctx->fallback.key_enc, ctx->key.rounds);
	kernel_neon_end();

	return 0;
}

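/*
 * CTR encryption: pass as many bytes as possible to the NEON code per walk
 * step; a trailing partial block is staged in a stack buffer and copied
 * back to the destination afterwards.
 */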
static int ctr_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	u8 buf[AES_BLOCK_SIZE];
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;
		int bytes = walk.nbytes;

		if (unlikely(bytes < AES_BLOCK_SIZE))
			src = dst = memcpy(buf + sizeof(buf) - bytes,
					   src, bytes);
		else if (walk.nbytes < walk.total)
			bytes &= ~(8 * AES_BLOCK_SIZE - 1);

		kernel_neon_begin();
		aesbs_ctr_encrypt(dst, src, ctx->rk, ctx->rounds, bytes, walk.iv);
		kernel_neon_end();

		if (unlikely(bytes < AES_BLOCK_SIZE))
			memcpy(walk.dst.virt.addr,
			       buf + sizeof(buf) - bytes, bytes);

		err = skcipher_walk_done(&walk, walk.nbytes - bytes);
	}

	return err;
}

static void ctr_encrypt_one(struct crypto_skcipher *tfm, const u8 *src, u8 *dst)
{
	struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	unsigned long flags;

	/*
	 * Temporarily disable interrupts to avoid races where
	 * cachelines are evicted when the CPU is interrupted
	 * to do something else.
	 */
	local_irq_save(flags);
	aes_encrypt(&ctx->fallback, dst, src);
	local_irq_restore(flags);
}

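/*
 * Synchronous CTR entry point: fall back to the scalar code when the NEON
 * unit must not be used in the current context.
 */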
static int ctr_encrypt_sync(struct skcipher_request *req)
{
	if (!crypto_simd_usable())
		return crypto_ctr_encrypt_walk(req, ctr_encrypt_one);

	return ctr_encrypt(req);
}

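/*
 * XTS setkey: the second half of the key goes to the tweak cipher, while
 * the first half keys both the bit-sliced code and the plain AES cipher
 * used for ciphertext stealing.
 */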
static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			    unsigned int key_len)
{
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err;

	err = xts_verify_key(tfm, in_key, key_len);
	if (err)
		return err;

	key_len /= 2;
	err = crypto_cipher_setkey(ctx->cts_tfm, in_key, key_len);
	if (err)
		return err;
	err = crypto_cipher_setkey(ctx->tweak_tfm, in_key + key_len, key_len);
	if (err)
		return err;

	return aesbs_setkey(tfm, in_key, key_len);
}

static int xts_init(struct crypto_skcipher *tfm)
{
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);

	ctx->cts_tfm = crypto_alloc_cipher("aes", 0, 0);
	if (IS_ERR(ctx->cts_tfm))
		return PTR_ERR(ctx->cts_tfm);

	ctx->tweak_tfm = crypto_alloc_cipher("aes", 0, 0);
	if (IS_ERR(ctx->tweak_tfm))
		crypto_free_cipher(ctx->cts_tfm);

	return PTR_ERR_OR_ZERO(ctx->tweak_tfm);
}

static void xts_exit(struct crypto_skcipher *tfm)
{
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);

	crypto_free_cipher(ctx->tweak_tfm);
	crypto_free_cipher(ctx->cts_tfm);
}

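/*
 * Common XTS helper: encrypt the IV to form the initial tweak, push whole
 * blocks through the NEON routine, and implement ciphertext stealing for a
 * trailing partial block using the plain AES cipher.
 */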
static int __xts_crypt(struct skcipher_request *req, bool encrypt,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[], int))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	int tail = req->cryptlen % AES_BLOCK_SIZE;
	struct skcipher_request subreq;
	u8 buf[2 * AES_BLOCK_SIZE];
	struct skcipher_walk walk;
	int err;

	if (req->cryptlen < AES_BLOCK_SIZE)
		return -EINVAL;

	if (unlikely(tail)) {
		skcipher_request_set_tfm(&subreq, tfm);
		skcipher_request_set_callback(&subreq,
					      skcipher_request_flags(req),
					      NULL, NULL);
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   req->cryptlen - tail, req->iv);
		req = &subreq;
	}

	err = skcipher_walk_virt(&walk, req, true);
	if (err)
		return err;

	crypto_cipher_encrypt_one(ctx->tweak_tfm, walk.iv, walk.iv);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
		int reorder_last_tweak = !encrypt && tail > 0;

		if (walk.nbytes < walk.total) {
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);
			reorder_last_tweak = 0;
		}

		kernel_neon_begin();
		fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->key.rk,
		   ctx->key.rounds, blocks, walk.iv, reorder_last_tweak);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	if (err || likely(!tail))
		return err;

	/* handle ciphertext stealing */
	scatterwalk_map_and_copy(buf, req->dst, req->cryptlen - AES_BLOCK_SIZE,
				 AES_BLOCK_SIZE, 0);
	memcpy(buf + AES_BLOCK_SIZE, buf, tail);
	scatterwalk_map_and_copy(buf, req->src, req->cryptlen, tail, 0);

	crypto_xor(buf, req->iv, AES_BLOCK_SIZE);

	if (encrypt)
		crypto_cipher_encrypt_one(ctx->cts_tfm, buf, buf);
	else
		crypto_cipher_decrypt_one(ctx->cts_tfm, buf, buf);

	crypto_xor(buf, req->iv, AES_BLOCK_SIZE);

	scatterwalk_map_and_copy(buf, req->dst, req->cryptlen - AES_BLOCK_SIZE,
				 AES_BLOCK_SIZE + tail, 1);
	return 0;
}

static int xts_encrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, true, aesbs_xts_encrypt);
}

static int xts_decrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, false, aesbs_xts_decrypt);
}

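/*
 * The "__"-prefixed algorithms are internal: they may only run where the
 * NEON unit is usable and are reached through the simd wrappers created in
 * aes_init(). The ctr-aes-neonbs-sync variant is exposed directly and
 * carries its own scalar fallback.
 */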
static struct skcipher_alg aes_algs[] = { {
	.base.cra_name		= "__ecb(aes)",
	.base.cra_driver_name	= "__ecb-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ecb_encrypt,
	.decrypt		= ecb_decrypt,
}, {
	.base.cra_name		= "__cbc(aes)",
	.base.cra_driver_name	= "__cbc-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_cbc_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL |
				  CRYPTO_ALG_NEED_FALLBACK,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_cbc_setkey,
	.encrypt		= cbc_encrypt,
	.decrypt		= cbc_decrypt,
	.init			= cbc_init,
	.exit			= cbc_exit,
}, {
	.base.cra_name		= "__ctr(aes)",
	.base.cra_driver_name	= "__ctr-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= 1,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.chunksize		= AES_BLOCK_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ctr_encrypt,
	.decrypt		= ctr_encrypt,
}, {
	.base.cra_name		= "ctr(aes)",
	.base.cra_driver_name	= "ctr-aes-neonbs-sync",
	.base.cra_priority	= 250 - 1,
	.base.cra_blocksize	= 1,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctr_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.chunksize		= AES_BLOCK_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_ctr_setkey_sync,
	.encrypt		= ctr_encrypt_sync,
	.decrypt		= ctr_encrypt_sync,
}, {
	.base.cra_name		= "__xts(aes)",
	.base.cra_driver_name	= "__xts-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_xts_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= 2 * AES_MIN_KEY_SIZE,
	.max_keysize		= 2 * AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_xts_setkey,
	.encrypt		= xts_encrypt,
	.decrypt		= xts_decrypt,
	.init			= xts_init,
	.exit			= xts_exit,
} };

static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];

static void aes_exit(void)
{
	int i;

	for (i = 0; i < ARRAY_SIZE(aes_simd_algs); i++)
		if (aes_simd_algs[i])
			simd_skcipher_free(aes_simd_algs[i]);

	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

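/*
 * Register all algorithms and create a simd wrapper (dropping the "__"
 * prefix) for each internal one so that it can be used from any context.
 */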
static int __init aes_init(void)
{
	struct simd_skcipher_alg *simd;
	const char *basename;
	const char *algname;
	const char *drvname;
	int err;
	int i;

	if (!(elf_hwcap & HWCAP_NEON))
		return -ENODEV;

	err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
	if (err)
		return err;

	for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
		if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
			continue;

		algname = aes_algs[i].base.cra_name + 2;
		drvname = aes_algs[i].base.cra_driver_name + 2;
		basename = aes_algs[i].base.cra_driver_name;
		simd = simd_skcipher_create_compat(algname, drvname, basename);
		err = PTR_ERR(simd);
		if (IS_ERR(simd))
			goto unregister_simds;

		aes_simd_algs[i] = simd;
	}
	return 0;

unregister_simds:
	aes_exit();
	return err;
}

late_initcall(aes_init);
module_exit(aes_exit);