cachepc-linux

Fork of AMDESE/linux with modifications for the CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux

sm4-neon-glue.c (11034B)


/* SPDX-License-Identifier: GPL-2.0-or-later */
/*
 * SM4 Cipher Algorithm, using ARMv8 NEON
 * as specified in
 * https://tools.ietf.org/id/draft-ribose-cfrg-sm4-10.html
 *
 * Copyright (C) 2022, Alibaba Group.
 * Copyright (C) 2022 Tianjia Zhang <tianjia.zhang@linux.alibaba.com>
 */

#include <linux/module.h>
#include <linux/crypto.h>
#include <linux/kernel.h>
#include <linux/cpufeature.h>
#include <asm/neon.h>
#include <asm/simd.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/sm4.h>

#define BYTES2BLKS(nbytes)	((nbytes) >> 4)
#define BYTES2BLK8(nbytes)	(((nbytes) >> 4) & ~(8 - 1))

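/*
 * Worked example of the helpers above: for a walk chunk of
 * nbytes = 200, BYTES2BLKS(200) = 200 >> 4 = 12 full 16-byte blocks,
 * and BYTES2BLK8(200) = 12 & ~7 = 8, i.e. the block count rounded down
 * to a multiple of eight for the wide NEON path. The remaining four
 * blocks go through the 1..8-block helper, and the final 8 bytes are
 * left to the mode's tail handling.
 */
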
asmlinkage void sm4_neon_crypt_blk1_8(const u32 *rkey, u8 *dst, const u8 *src,
				      unsigned int nblks);
asmlinkage void sm4_neon_crypt_blk8(const u32 *rkey, u8 *dst, const u8 *src,
				    unsigned int nblks);
asmlinkage void sm4_neon_cbc_dec_blk8(const u32 *rkey, u8 *dst, const u8 *src,
				      u8 *iv, unsigned int nblks);
asmlinkage void sm4_neon_cfb_dec_blk8(const u32 *rkey, u8 *dst, const u8 *src,
				      u8 *iv, unsigned int nblks);
asmlinkage void sm4_neon_ctr_enc_blk8(const u32 *rkey, u8 *dst, const u8 *src,
				      u8 *iv, unsigned int nblks);
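
/*
 * The five entry points above are implemented in the accompanying
 * ARMv8 NEON assembly (presumably sm4-neon-core.S, as in the mainline
 * tree). The blk8 variants process exactly eight blocks per call;
 * blk1_8 handles any residual count from one to eight blocks.
 */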

static int sm4_setkey(struct crypto_skcipher *tfm, const u8 *key,
		      unsigned int key_len)
{
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

	return sm4_expandkey(ctx, key, key_len);
}

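/*
 * Common ECB walker: for each chunk returned by the skcipher walk, run
 * as many groups of eight blocks as possible through the wide NEON
 * routine, then hand the remaining one to seven blocks to the
 * 1..8-block helper, all inside one kernel_neon_begin()/end() section.
 * ECB has no chaining, so encrypt and decrypt differ only in the
 * round-key schedule passed in.
 */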
static int sm4_ecb_do_crypt(struct skcipher_request *req, const u32 *rkey)
{
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;
		unsigned int nblks;

		kernel_neon_begin();

		nblks = BYTES2BLK8(nbytes);
		if (nblks) {
			sm4_neon_crypt_blk8(rkey, dst, src, nblks);
			dst += nblks * SM4_BLOCK_SIZE;
			src += nblks * SM4_BLOCK_SIZE;
			nbytes -= nblks * SM4_BLOCK_SIZE;
		}

		nblks = BYTES2BLKS(nbytes);
		if (nblks) {
			sm4_neon_crypt_blk1_8(rkey, dst, src, nblks);
			nbytes -= nblks * SM4_BLOCK_SIZE;
		}

		kernel_neon_end();

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}

static int sm4_ecb_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

	return sm4_ecb_do_crypt(req, ctx->rkey_enc);
}

static int sm4_ecb_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

	return sm4_ecb_do_crypt(req, ctx->rkey_dec);
}

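/*
 * CBC encryption is inherently serial: each plaintext block is XORed
 * with the previous ciphertext block before encryption, so there is
 * nothing to vectorize and the scalar sm4_crypt_block() from the SM4
 * library code is used one block at a time.
 */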
static int sm4_cbc_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *iv = walk.iv;
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		while (nbytes >= SM4_BLOCK_SIZE) {
			crypto_xor_cpy(dst, src, iv, SM4_BLOCK_SIZE);
			sm4_crypt_block(ctx->rkey_enc, dst, dst);
			iv = dst;
			src += SM4_BLOCK_SIZE;
			dst += SM4_BLOCK_SIZE;
			nbytes -= SM4_BLOCK_SIZE;
		}
		if (iv != walk.iv)
			memcpy(walk.iv, iv, SM4_BLOCK_SIZE);

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}

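/*
 * CBC decryption parallelizes, since all ciphertext blocks are already
 * available. Full groups of eight go through the NEON routine; the
 * residual path decrypts up to seven blocks into a temporary buffer
 * and then XORs from the last block backwards, so the request can be
 * processed in place (dst may alias src). The (int)nblks - 2 cast
 * keeps the pointer arithmetic sane when nblks == 1, and the last
 * ciphertext block is saved up front to become the next IV.
 */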
static int sm4_cbc_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;
		unsigned int nblks;

		kernel_neon_begin();

		nblks = BYTES2BLK8(nbytes);
		if (nblks) {
			sm4_neon_cbc_dec_blk8(ctx->rkey_dec, dst, src,
					walk.iv, nblks);
			dst += nblks * SM4_BLOCK_SIZE;
			src += nblks * SM4_BLOCK_SIZE;
			nbytes -= nblks * SM4_BLOCK_SIZE;
		}

		nblks = BYTES2BLKS(nbytes);
		if (nblks) {
			u8 keystream[SM4_BLOCK_SIZE * 8];
			u8 iv[SM4_BLOCK_SIZE];
			int i;

			sm4_neon_crypt_blk1_8(ctx->rkey_dec, keystream,
					src, nblks);

			src += ((int)nblks - 2) * SM4_BLOCK_SIZE;
			dst += (nblks - 1) * SM4_BLOCK_SIZE;
			memcpy(iv, src + SM4_BLOCK_SIZE, SM4_BLOCK_SIZE);

			for (i = nblks - 1; i > 0; i--) {
				crypto_xor_cpy(dst, src,
					&keystream[i * SM4_BLOCK_SIZE],
					SM4_BLOCK_SIZE);
				src -= SM4_BLOCK_SIZE;
				dst -= SM4_BLOCK_SIZE;
			}
			crypto_xor_cpy(dst, walk.iv,
					keystream, SM4_BLOCK_SIZE);
			memcpy(walk.iv, iv, SM4_BLOCK_SIZE);
			nbytes -= nblks * SM4_BLOCK_SIZE;
		}

		kernel_neon_end();

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}

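/*
 * CFB encryption is serial like CBC encryption: the keystream block is
 * the encryption of the previous ciphertext block, so it runs one
 * scalar block at a time. Because CFB is registered as a stream mode
 * (cra_blocksize is 1), a final partial block is handled by XORing the
 * plaintext tail with the head of one more keystream block.
 */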
static int sm4_cfb_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		u8 keystream[SM4_BLOCK_SIZE];
		const u8 *iv = walk.iv;
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		while (nbytes >= SM4_BLOCK_SIZE) {
			sm4_crypt_block(ctx->rkey_enc, keystream, iv);
			crypto_xor_cpy(dst, src, keystream, SM4_BLOCK_SIZE);
			iv = dst;
			src += SM4_BLOCK_SIZE;
			dst += SM4_BLOCK_SIZE;
			nbytes -= SM4_BLOCK_SIZE;
		}
		if (iv != walk.iv)
			memcpy(walk.iv, iv, SM4_BLOCK_SIZE);

		/* tail */
		if (walk.nbytes == walk.total && nbytes > 0) {
			sm4_crypt_block(ctx->rkey_enc, keystream, walk.iv);
			crypto_xor_cpy(dst, src, keystream, nbytes);
			nbytes = 0;
		}

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}

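/*
 * CFB decryption parallelizes: the keystream for block i is the
 * encryption of ciphertext block i - 1 (the IV for block 0), all of
 * which are known in advance. The residual path assembles
 * IV || C0..C(n-2) in a stack buffer, saves the last ciphertext block
 * as the next IV before encrypting (so in-place requests work),
 * encrypts the whole buffer with the 1..8-block helper, and XORs it
 * onto the ciphertext. The encryption round keys are used for both
 * directions, as in all feedback modes.
 */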
static int sm4_cfb_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;
		unsigned int nblks;

		kernel_neon_begin();

		nblks = BYTES2BLK8(nbytes);
		if (nblks) {
			sm4_neon_cfb_dec_blk8(ctx->rkey_enc, dst, src,
					walk.iv, nblks);
			dst += nblks * SM4_BLOCK_SIZE;
			src += nblks * SM4_BLOCK_SIZE;
			nbytes -= nblks * SM4_BLOCK_SIZE;
		}

		nblks = BYTES2BLKS(nbytes);
		if (nblks) {
			u8 keystream[SM4_BLOCK_SIZE * 8];

			memcpy(keystream, walk.iv, SM4_BLOCK_SIZE);
			if (nblks > 1)
				memcpy(&keystream[SM4_BLOCK_SIZE], src,
					(nblks - 1) * SM4_BLOCK_SIZE);
			memcpy(walk.iv, src + (nblks - 1) * SM4_BLOCK_SIZE,
				SM4_BLOCK_SIZE);

			sm4_neon_crypt_blk1_8(ctx->rkey_enc, keystream,
					keystream, nblks);

			crypto_xor_cpy(dst, src, keystream,
					nblks * SM4_BLOCK_SIZE);
			dst += nblks * SM4_BLOCK_SIZE;
			src += nblks * SM4_BLOCK_SIZE;
			nbytes -= nblks * SM4_BLOCK_SIZE;
		}

		kernel_neon_end();

		/* tail */
		if (walk.nbytes == walk.total && nbytes > 0) {
			u8 keystream[SM4_BLOCK_SIZE];

			sm4_crypt_block(ctx->rkey_enc, keystream, walk.iv);
			crypto_xor_cpy(dst, src, keystream, nbytes);
			nbytes = 0;
		}

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}

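/*
 * CTR mode: the keystream is the encryption of a big-endian counter
 * kept in walk.iv and advanced once per block with crypto_inc(). The
 * counter blocks are independent, so eight at a time go through the
 * NEON routine; the residual path materializes the counter blocks in a
 * stack buffer first. Encryption and decryption are the same
 * operation, so both callbacks point here.
 */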
static int sm4_ctr_crypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;
		unsigned int nblks;

		kernel_neon_begin();

		nblks = BYTES2BLK8(nbytes);
		if (nblks) {
			sm4_neon_ctr_enc_blk8(ctx->rkey_enc, dst, src,
					walk.iv, nblks);
			dst += nblks * SM4_BLOCK_SIZE;
			src += nblks * SM4_BLOCK_SIZE;
			nbytes -= nblks * SM4_BLOCK_SIZE;
		}

		nblks = BYTES2BLKS(nbytes);
		if (nblks) {
			u8 keystream[SM4_BLOCK_SIZE * 8];
			int i;

			for (i = 0; i < nblks; i++) {
				memcpy(&keystream[i * SM4_BLOCK_SIZE],
					walk.iv, SM4_BLOCK_SIZE);
				crypto_inc(walk.iv, SM4_BLOCK_SIZE);
			}
			sm4_neon_crypt_blk1_8(ctx->rkey_enc, keystream,
					keystream, nblks);

			crypto_xor_cpy(dst, src, keystream,
					nblks * SM4_BLOCK_SIZE);
			dst += nblks * SM4_BLOCK_SIZE;
			src += nblks * SM4_BLOCK_SIZE;
			nbytes -= nblks * SM4_BLOCK_SIZE;
		}

		kernel_neon_end();

		/* tail */
		if (walk.nbytes == walk.total && nbytes > 0) {
			u8 keystream[SM4_BLOCK_SIZE];

			sm4_crypt_block(ctx->rkey_enc, keystream, walk.iv);
			crypto_inc(walk.iv, SM4_BLOCK_SIZE);
			crypto_xor_cpy(dst, src, keystream, nbytes);
			nbytes = 0;
		}

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}

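/*
 * cra_priority 200 ranks this driver above the generic C sm4
 * implementation (which registers at a lower priority, 100 in
 * mainline), so it wins algorithm selection when both are available.
 * ECB and CBC advertise the 16-byte SM4 block size; CFB and CTR set
 * cra_blocksize to 1 as stream modes and report the underlying block
 * granularity via chunksize instead.
 */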
static struct skcipher_alg sm4_algs[] = {
	{
		.base = {
			.cra_name		= "ecb(sm4)",
			.cra_driver_name	= "ecb-sm4-neon",
			.cra_priority		= 200,
			.cra_blocksize		= SM4_BLOCK_SIZE,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.setkey		= sm4_setkey,
		.encrypt	= sm4_ecb_encrypt,
		.decrypt	= sm4_ecb_decrypt,
	}, {
		.base = {
			.cra_name		= "cbc(sm4)",
			.cra_driver_name	= "cbc-sm4-neon",
			.cra_priority		= 200,
			.cra_blocksize		= SM4_BLOCK_SIZE,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.ivsize		= SM4_BLOCK_SIZE,
		.setkey		= sm4_setkey,
		.encrypt	= sm4_cbc_encrypt,
		.decrypt	= sm4_cbc_decrypt,
	}, {
		.base = {
			.cra_name		= "cfb(sm4)",
			.cra_driver_name	= "cfb-sm4-neon",
			.cra_priority		= 200,
			.cra_blocksize		= 1,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.ivsize		= SM4_BLOCK_SIZE,
		.chunksize	= SM4_BLOCK_SIZE,
		.setkey		= sm4_setkey,
		.encrypt	= sm4_cfb_encrypt,
		.decrypt	= sm4_cfb_decrypt,
	}, {
		.base = {
			.cra_name		= "ctr(sm4)",
			.cra_driver_name	= "ctr-sm4-neon",
			.cra_priority		= 200,
			.cra_blocksize		= 1,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.ivsize		= SM4_BLOCK_SIZE,
		.chunksize	= SM4_BLOCK_SIZE,
		.setkey		= sm4_setkey,
		.encrypt	= sm4_ctr_crypt,
		.decrypt	= sm4_ctr_crypt,
	}
};

static int __init sm4_init(void)
{
	return crypto_register_skciphers(sm4_algs, ARRAY_SIZE(sm4_algs));
}

static void __exit sm4_exit(void)
{
	crypto_unregister_skciphers(sm4_algs, ARRAY_SIZE(sm4_algs));
}

module_init(sm4_init);
module_exit(sm4_exit);

MODULE_DESCRIPTION("SM4 ECB/CBC/CFB/CTR using ARMv8 NEON");
MODULE_ALIAS_CRYPTO("sm4-neon");
MODULE_ALIAS_CRYPTO("sm4");
MODULE_ALIAS_CRYPTO("ecb(sm4)");
MODULE_ALIAS_CRYPTO("cbc(sm4)");
MODULE_ALIAS_CRYPTO("cfb(sm4)");
MODULE_ALIAS_CRYPTO("ctr(sm4)");
MODULE_AUTHOR("Tianjia Zhang <tianjia.zhang@linux.alibaba.com>");
MODULE_LICENSE("GPL v2");
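
Usage note: a quick way to exercise the skciphers registered above is the
kernel's AF_ALG userspace interface. The sketch below is not part of the file;
it encrypts two blocks with "cbc(sm4)" on a kernel built with
CONFIG_CRYPTO_USER_API_SKCIPHER. The all-zero key and IV are placeholders and
error handling is omitted for brevity.

#include <stdio.h>
#include <string.h>
#include <unistd.h>
#include <sys/socket.h>
#include <sys/uio.h>
#include <linux/if_alg.h>

#ifndef SOL_ALG
#define SOL_ALG 279
#endif

int main(void)
{
	struct sockaddr_alg sa = {
		.salg_family = AF_ALG,
		.salg_type   = "skcipher",
		.salg_name   = "cbc(sm4)",
	};
	unsigned char key[16] = { 0 };		/* SM4_KEY_SIZE */
	unsigned char iv[16] = { 0 };		/* SM4_BLOCK_SIZE */
	unsigned char buf[32] = "two blocks of demo plaintext...";
	char cbuf[CMSG_SPACE(sizeof(int)) +
		  CMSG_SPACE(sizeof(struct af_alg_iv) + sizeof(iv))] = { 0 };
	struct iovec iov = { .iov_base = buf, .iov_len = sizeof(buf) };
	struct msghdr msg = {
		.msg_control = cbuf, .msg_controllen = sizeof(cbuf),
		.msg_iov = &iov, .msg_iovlen = 1,
	};
	struct cmsghdr *cmsg;
	struct af_alg_iv *aiv;
	int tfmfd, opfd, i;

	/* Bind a transform socket to the skcipher and set the key. */
	tfmfd = socket(AF_ALG, SOCK_SEQPACKET, 0);
	bind(tfmfd, (struct sockaddr *)&sa, sizeof(sa));
	setsockopt(tfmfd, SOL_ALG, ALG_SET_KEY, key, sizeof(key));
	opfd = accept(tfmfd, NULL, 0);

	/* First control message selects the operation... */
	cmsg = CMSG_FIRSTHDR(&msg);
	cmsg->cmsg_level = SOL_ALG;
	cmsg->cmsg_type = ALG_SET_OP;
	cmsg->cmsg_len = CMSG_LEN(sizeof(int));
	*(int *)CMSG_DATA(cmsg) = ALG_OP_ENCRYPT;

	/* ...the second carries the IV. */
	cmsg = CMSG_NXTHDR(&msg, cmsg);
	cmsg->cmsg_level = SOL_ALG;
	cmsg->cmsg_type = ALG_SET_IV;
	cmsg->cmsg_len = CMSG_LEN(sizeof(*aiv) + sizeof(iv));
	aiv = (struct af_alg_iv *)CMSG_DATA(cmsg);
	aiv->ivlen = sizeof(iv);
	memcpy(aiv->iv, iv, sizeof(iv));

	/* Send the plaintext, then read the ciphertext back in place. */
	sendmsg(opfd, &msg, 0);
	read(opfd, buf, sizeof(buf));

	for (i = 0; i < (int)sizeof(buf); i++)
		printf("%02x", buf[i]);
	printf("\n");

	close(opfd);
	close(tfmfd);
	return 0;
}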