cachepc-linux

Fork of AMDESE/linux with modifications for the CachePC side-channel attack.
git clone https://git.sinitax.com/sinitax/cachepc-linux
Log | Files | Refs | README | LICENSE | sfeed.txt

sm4-ce-glue.c (9225B)


      1/* SPDX-License-Identifier: GPL-2.0-or-later */
      2/*
      3 * SM4 Cipher Algorithm, using ARMv8 Crypto Extensions
      4 * as specified in
      5 * https://tools.ietf.org/id/draft-ribose-cfrg-sm4-10.html
      6 *
      7 * Copyright (C) 2022, Alibaba Group.
      8 * Copyright (C) 2022 Tianjia Zhang <tianjia.zhang@linux.alibaba.com>
      9 */
     10
     11#include <linux/module.h>
     12#include <linux/crypto.h>
     13#include <linux/kernel.h>
     14#include <linux/cpufeature.h>
     15#include <asm/neon.h>
     16#include <asm/simd.h>
     17#include <crypto/internal/simd.h>
     18#include <crypto/internal/skcipher.h>
     19#include <crypto/sm4.h>
     20
     21#define BYTES2BLKS(nbytes)	((nbytes) >> 4)
     22
     23asmlinkage void sm4_ce_expand_key(const u8 *key, u32 *rkey_enc, u32 *rkey_dec,
     24				  const u32 *fk, const u32 *ck);
     25asmlinkage void sm4_ce_crypt_block(const u32 *rkey, u8 *dst, const u8 *src);
     26asmlinkage void sm4_ce_crypt(const u32 *rkey, u8 *dst, const u8 *src,
     27			     unsigned int nblks);
     28asmlinkage void sm4_ce_cbc_enc(const u32 *rkey, u8 *dst, const u8 *src,
     29			       u8 *iv, unsigned int nblks);
     30asmlinkage void sm4_ce_cbc_dec(const u32 *rkey, u8 *dst, const u8 *src,
     31			       u8 *iv, unsigned int nblks);
     32asmlinkage void sm4_ce_cfb_enc(const u32 *rkey, u8 *dst, const u8 *src,
     33			       u8 *iv, unsigned int nblks);
     34asmlinkage void sm4_ce_cfb_dec(const u32 *rkey, u8 *dst, const u8 *src,
     35			       u8 *iv, unsigned int nblks);
     36asmlinkage void sm4_ce_ctr_enc(const u32 *rkey, u8 *dst, const u8 *src,
     37			       u8 *iv, unsigned int nblks);
     38
     39static int sm4_setkey(struct crypto_skcipher *tfm, const u8 *key,
     40		      unsigned int key_len)
     41{
     42	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
     43
     44	if (key_len != SM4_KEY_SIZE)
     45		return -EINVAL;
     46
     47	sm4_ce_expand_key(key, ctx->rkey_enc, ctx->rkey_dec,
     48			  crypto_sm4_fk, crypto_sm4_ck);
     49	return 0;
     50}
     51
     52static int sm4_ecb_do_crypt(struct skcipher_request *req, const u32 *rkey)
     53{
     54	struct skcipher_walk walk;
     55	unsigned int nbytes;
     56	int err;
     57
     58	err = skcipher_walk_virt(&walk, req, false);
     59
     60	while ((nbytes = walk.nbytes) > 0) {
     61		const u8 *src = walk.src.virt.addr;
     62		u8 *dst = walk.dst.virt.addr;
     63		unsigned int nblks;
     64
     65		kernel_neon_begin();
     66
     67		nblks = BYTES2BLKS(nbytes);
     68		if (nblks) {
     69			sm4_ce_crypt(rkey, dst, src, nblks);
     70			nbytes -= nblks * SM4_BLOCK_SIZE;
     71		}
     72
     73		kernel_neon_end();
     74
     75		err = skcipher_walk_done(&walk, nbytes);
     76	}
     77
     78	return err;
     79}
     80
     81static int sm4_ecb_encrypt(struct skcipher_request *req)
     82{
     83	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
     84	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
     85
     86	return sm4_ecb_do_crypt(req, ctx->rkey_enc);
     87}
     88
     89static int sm4_ecb_decrypt(struct skcipher_request *req)
     90{
     91	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
     92	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
     93
     94	return sm4_ecb_do_crypt(req, ctx->rkey_dec);
     95}
     96
     97static int sm4_cbc_encrypt(struct skcipher_request *req)
     98{
     99	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
    100	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
    101	struct skcipher_walk walk;
    102	unsigned int nbytes;
    103	int err;
    104
    105	err = skcipher_walk_virt(&walk, req, false);
    106
    107	while ((nbytes = walk.nbytes) > 0) {
    108		const u8 *src = walk.src.virt.addr;
    109		u8 *dst = walk.dst.virt.addr;
    110		unsigned int nblks;
    111
    112		kernel_neon_begin();
    113
    114		nblks = BYTES2BLKS(nbytes);
    115		if (nblks) {
    116			sm4_ce_cbc_enc(ctx->rkey_enc, dst, src, walk.iv, nblks);
    117			nbytes -= nblks * SM4_BLOCK_SIZE;
    118		}
    119
    120		kernel_neon_end();
    121
    122		err = skcipher_walk_done(&walk, nbytes);
    123	}
    124
    125	return err;
    126}
    127
    128static int sm4_cbc_decrypt(struct skcipher_request *req)
    129{
    130	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
    131	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
    132	struct skcipher_walk walk;
    133	unsigned int nbytes;
    134	int err;
    135
    136	err = skcipher_walk_virt(&walk, req, false);
    137
    138	while ((nbytes = walk.nbytes) > 0) {
    139		const u8 *src = walk.src.virt.addr;
    140		u8 *dst = walk.dst.virt.addr;
    141		unsigned int nblks;
    142
    143		kernel_neon_begin();
    144
    145		nblks = BYTES2BLKS(nbytes);
    146		if (nblks) {
    147			sm4_ce_cbc_dec(ctx->rkey_dec, dst, src, walk.iv, nblks);
    148			nbytes -= nblks * SM4_BLOCK_SIZE;
    149		}
    150
    151		kernel_neon_end();
    152
    153		err = skcipher_walk_done(&walk, nbytes);
    154	}
    155
    156	return err;
    157}
    158
    159static int sm4_cfb_encrypt(struct skcipher_request *req)
    160{
    161	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
    162	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
    163	struct skcipher_walk walk;
    164	unsigned int nbytes;
    165	int err;
    166
    167	err = skcipher_walk_virt(&walk, req, false);
    168
    169	while ((nbytes = walk.nbytes) > 0) {
    170		const u8 *src = walk.src.virt.addr;
    171		u8 *dst = walk.dst.virt.addr;
    172		unsigned int nblks;
    173
    174		kernel_neon_begin();
    175
    176		nblks = BYTES2BLKS(nbytes);
    177		if (nblks) {
    178			sm4_ce_cfb_enc(ctx->rkey_enc, dst, src, walk.iv, nblks);
    179			dst += nblks * SM4_BLOCK_SIZE;
    180			src += nblks * SM4_BLOCK_SIZE;
    181			nbytes -= nblks * SM4_BLOCK_SIZE;
    182		}
    183
    184		/* tail */
    185		if (walk.nbytes == walk.total && nbytes > 0) {
    186			u8 keystream[SM4_BLOCK_SIZE];
    187
    188			sm4_ce_crypt_block(ctx->rkey_enc, keystream, walk.iv);
    189			crypto_xor_cpy(dst, src, keystream, nbytes);
    190			nbytes = 0;
    191		}
    192
    193		kernel_neon_end();
    194
    195		err = skcipher_walk_done(&walk, nbytes);
    196	}
    197
    198	return err;
    199}
    200
    201static int sm4_cfb_decrypt(struct skcipher_request *req)
    202{
    203	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
    204	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
    205	struct skcipher_walk walk;
    206	unsigned int nbytes;
    207	int err;
    208
    209	err = skcipher_walk_virt(&walk, req, false);
    210
    211	while ((nbytes = walk.nbytes) > 0) {
    212		const u8 *src = walk.src.virt.addr;
    213		u8 *dst = walk.dst.virt.addr;
    214		unsigned int nblks;
    215
    216		kernel_neon_begin();
    217
    218		nblks = BYTES2BLKS(nbytes);
    219		if (nblks) {
    220			sm4_ce_cfb_dec(ctx->rkey_enc, dst, src, walk.iv, nblks);
    221			dst += nblks * SM4_BLOCK_SIZE;
    222			src += nblks * SM4_BLOCK_SIZE;
    223			nbytes -= nblks * SM4_BLOCK_SIZE;
    224		}
    225
    226		/* tail */
    227		if (walk.nbytes == walk.total && nbytes > 0) {
    228			u8 keystream[SM4_BLOCK_SIZE];
    229
    230			sm4_ce_crypt_block(ctx->rkey_enc, keystream, walk.iv);
    231			crypto_xor_cpy(dst, src, keystream, nbytes);
    232			nbytes = 0;
    233		}
    234
    235		kernel_neon_end();
    236
    237		err = skcipher_walk_done(&walk, nbytes);
    238	}
    239
    240	return err;
    241}
    242
    243static int sm4_ctr_crypt(struct skcipher_request *req)
    244{
    245	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
    246	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
    247	struct skcipher_walk walk;
    248	unsigned int nbytes;
    249	int err;
    250
    251	err = skcipher_walk_virt(&walk, req, false);
    252
    253	while ((nbytes = walk.nbytes) > 0) {
    254		const u8 *src = walk.src.virt.addr;
    255		u8 *dst = walk.dst.virt.addr;
    256		unsigned int nblks;
    257
    258		kernel_neon_begin();
    259
    260		nblks = BYTES2BLKS(nbytes);
    261		if (nblks) {
    262			sm4_ce_ctr_enc(ctx->rkey_enc, dst, src, walk.iv, nblks);
    263			dst += nblks * SM4_BLOCK_SIZE;
    264			src += nblks * SM4_BLOCK_SIZE;
    265			nbytes -= nblks * SM4_BLOCK_SIZE;
    266		}
    267
    268		/* tail */
    269		if (walk.nbytes == walk.total && nbytes > 0) {
    270			u8 keystream[SM4_BLOCK_SIZE];
    271
    272			sm4_ce_crypt_block(ctx->rkey_enc, keystream, walk.iv);
    273			crypto_inc(walk.iv, SM4_BLOCK_SIZE);
    274			crypto_xor_cpy(dst, src, keystream, nbytes);
    275			nbytes = 0;
    276		}
    277
    278		kernel_neon_end();
    279
    280		err = skcipher_walk_done(&walk, nbytes);
    281	}
    282
    283	return err;
    284}
    285
    286static struct skcipher_alg sm4_algs[] = {
    287	{
    288		.base = {
    289			.cra_name		= "ecb(sm4)",
    290			.cra_driver_name	= "ecb-sm4-ce",
    291			.cra_priority		= 400,
    292			.cra_blocksize		= SM4_BLOCK_SIZE,
    293			.cra_ctxsize		= sizeof(struct sm4_ctx),
    294			.cra_module		= THIS_MODULE,
    295		},
    296		.min_keysize	= SM4_KEY_SIZE,
    297		.max_keysize	= SM4_KEY_SIZE,
    298		.setkey		= sm4_setkey,
    299		.encrypt	= sm4_ecb_encrypt,
    300		.decrypt	= sm4_ecb_decrypt,
    301	}, {
    302		.base = {
    303			.cra_name		= "cbc(sm4)",
    304			.cra_driver_name	= "cbc-sm4-ce",
    305			.cra_priority		= 400,
    306			.cra_blocksize		= SM4_BLOCK_SIZE,
    307			.cra_ctxsize		= sizeof(struct sm4_ctx),
    308			.cra_module		= THIS_MODULE,
    309		},
    310		.min_keysize	= SM4_KEY_SIZE,
    311		.max_keysize	= SM4_KEY_SIZE,
    312		.ivsize		= SM4_BLOCK_SIZE,
    313		.setkey		= sm4_setkey,
    314		.encrypt	= sm4_cbc_encrypt,
    315		.decrypt	= sm4_cbc_decrypt,
    316	}, {
    317		.base = {
    318			.cra_name		= "cfb(sm4)",
    319			.cra_driver_name	= "cfb-sm4-ce",
    320			.cra_priority		= 400,
    321			.cra_blocksize		= 1,
    322			.cra_ctxsize		= sizeof(struct sm4_ctx),
    323			.cra_module		= THIS_MODULE,
    324		},
    325		.min_keysize	= SM4_KEY_SIZE,
    326		.max_keysize	= SM4_KEY_SIZE,
    327		.ivsize		= SM4_BLOCK_SIZE,
    328		.chunksize	= SM4_BLOCK_SIZE,
    329		.setkey		= sm4_setkey,
    330		.encrypt	= sm4_cfb_encrypt,
    331		.decrypt	= sm4_cfb_decrypt,
    332	}, {
    333		.base = {
    334			.cra_name		= "ctr(sm4)",
    335			.cra_driver_name	= "ctr-sm4-ce",
    336			.cra_priority		= 400,
    337			.cra_blocksize		= 1,
    338			.cra_ctxsize		= sizeof(struct sm4_ctx),
    339			.cra_module		= THIS_MODULE,
    340		},
    341		.min_keysize	= SM4_KEY_SIZE,
    342		.max_keysize	= SM4_KEY_SIZE,
    343		.ivsize		= SM4_BLOCK_SIZE,
    344		.chunksize	= SM4_BLOCK_SIZE,
    345		.setkey		= sm4_setkey,
    346		.encrypt	= sm4_ctr_crypt,
    347		.decrypt	= sm4_ctr_crypt,
    348	}
    349};
    350
    351static int __init sm4_init(void)
    352{
    353	return crypto_register_skciphers(sm4_algs, ARRAY_SIZE(sm4_algs));
    354}
    355
    356static void __exit sm4_exit(void)
    357{
    358	crypto_unregister_skciphers(sm4_algs, ARRAY_SIZE(sm4_algs));
    359}
    360
/*
 * Module entry/exit. The init hook is tied to the SM4 CPU feature via
 * module_cpu_feature_match(), so presumably the algorithms are only
 * registered on CPUs that advertise SM4 — confirm against the macro.
 */
module_cpu_feature_match(SM4, sm4_init);
module_exit(sm4_exit);

MODULE_DESCRIPTION("SM4 ECB/CBC/CFB/CTR using ARMv8 Crypto Extensions");
/*
 * NOTE(review): sm4_algs[] registers only the ecb/cbc/cfb/ctr modes. No plain
 * "sm4" or "sm4-ce" algorithm is registered in this file, so these two aliases
 * presumably serve only name-based module loading — confirm they are intended.
 */
MODULE_ALIAS_CRYPTO("sm4-ce");
MODULE_ALIAS_CRYPTO("sm4");
/* One alias per mode registered in sm4_algs[]. */
MODULE_ALIAS_CRYPTO("ecb(sm4)");
MODULE_ALIAS_CRYPTO("cbc(sm4)");
MODULE_ALIAS_CRYPTO("cfb(sm4)");
MODULE_ALIAS_CRYPTO("ctr(sm4)");
MODULE_AUTHOR("Tianjia Zhang <tianjia.zhang@linux.alibaba.com>");
MODULE_LICENSE("GPL v2");