cachepc-linux

Fork of AMDESE/linux with modifications for the CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux

sm4_aesni_avx_glue.c (12594B)


/* SPDX-License-Identifier: GPL-2.0-or-later */
/*
 * SM4 Cipher Algorithm, AES-NI/AVX optimized.
 * as specified in
 * https://tools.ietf.org/id/draft-ribose-cfrg-sm4-10.html
 *
 * Copyright (c) 2021, Alibaba Group.
 * Copyright (c) 2021 Tianjia Zhang <tianjia.zhang@linux.alibaba.com>
 */

#include <linux/module.h>
#include <linux/crypto.h>
#include <linux/kernel.h>
#include <asm/simd.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/sm4.h>
#include "sm4-avx.h"

#define SM4_CRYPT8_BLOCK_SIZE	(SM4_BLOCK_SIZE * 8)

asmlinkage void sm4_aesni_avx_crypt4(const u32 *rk, u8 *dst,
				const u8 *src, int nblocks);
asmlinkage void sm4_aesni_avx_crypt8(const u32 *rk, u8 *dst,
				const u8 *src, int nblocks);
asmlinkage void sm4_aesni_avx_ctr_enc_blk8(const u32 *rk, u8 *dst,
				const u8 *src, u8 *iv);
asmlinkage void sm4_aesni_avx_cbc_dec_blk8(const u32 *rk, u8 *dst,
				const u8 *src, u8 *iv);
asmlinkage void sm4_aesni_avx_cfb_dec_blk8(const u32 *rk, u8 *dst,
				const u8 *src, u8 *iv);

static int sm4_skcipher_setkey(struct crypto_skcipher *tfm, const u8 *key,
			unsigned int key_len)
{
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

	return sm4_expandkey(ctx, key, key_len);
}

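/*
 * ECB helper: encrypt/decrypt the walk buffers with the given round keys.
 * Batches of 8 blocks go through sm4_aesni_avx_crypt8(), any remainder is
 * handled up to 4 blocks at a time by sm4_aesni_avx_crypt4(), all inside
 * a single kernel_fpu_begin()/kernel_fpu_end() section per walk step.
 */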
static int ecb_do_crypt(struct skcipher_request *req, const u32 *rkey)
{
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		kernel_fpu_begin();
		while (nbytes >= SM4_CRYPT8_BLOCK_SIZE) {
			sm4_aesni_avx_crypt8(rkey, dst, src, 8);
			dst += SM4_CRYPT8_BLOCK_SIZE;
			src += SM4_CRYPT8_BLOCK_SIZE;
			nbytes -= SM4_CRYPT8_BLOCK_SIZE;
		}
		while (nbytes >= SM4_BLOCK_SIZE) {
			unsigned int nblocks = min(nbytes >> 4, 4u);
			sm4_aesni_avx_crypt4(rkey, dst, src, nblocks);
			dst += nblocks * SM4_BLOCK_SIZE;
			src += nblocks * SM4_BLOCK_SIZE;
			nbytes -= nblocks * SM4_BLOCK_SIZE;
		}
		kernel_fpu_end();

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}

int sm4_avx_ecb_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

	return ecb_do_crypt(req, ctx->rkey_enc);
}
EXPORT_SYMBOL_GPL(sm4_avx_ecb_encrypt);

int sm4_avx_ecb_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

	return ecb_do_crypt(req, ctx->rkey_dec);
}
EXPORT_SYMBOL_GPL(sm4_avx_ecb_decrypt);

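/*
 * CBC encryption chains each ciphertext block into the next plaintext
 * block, so it cannot be parallelized; it uses the generic per-block
 * sm4_crypt_block() and therefore needs no FPU section.
 */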
int sm4_cbc_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *iv = walk.iv;
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		while (nbytes >= SM4_BLOCK_SIZE) {
			crypto_xor_cpy(dst, src, iv, SM4_BLOCK_SIZE);
			sm4_crypt_block(ctx->rkey_enc, dst, dst);
			iv = dst;
			src += SM4_BLOCK_SIZE;
			dst += SM4_BLOCK_SIZE;
			nbytes -= SM4_BLOCK_SIZE;
		}
		if (iv != walk.iv)
			memcpy(walk.iv, iv, SM4_BLOCK_SIZE);

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}
EXPORT_SYMBOL_GPL(sm4_cbc_encrypt);

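/*
 * CBC decryption: full groups of 'bsize' bytes are decrypted by the
 * assembly bulk helper 'func'. The remainder is handled in chunks of up
 * to eight blocks: the raw block decryptions are computed with
 * sm4_aesni_avx_crypt8() and then XORed with the preceding ciphertext
 * block, walking from the last block to the first so that in-place
 * requests (dst == src) never overwrite a ciphertext block that is
 * still needed as a chaining value; the last ciphertext block is saved
 * beforehand to become the next IV.
 */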
int sm4_avx_cbc_decrypt(struct skcipher_request *req,
			unsigned int bsize, sm4_crypt_func func)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		kernel_fpu_begin();

		while (nbytes >= bsize) {
			func(ctx->rkey_dec, dst, src, walk.iv);
			dst += bsize;
			src += bsize;
			nbytes -= bsize;
		}

		while (nbytes >= SM4_BLOCK_SIZE) {
			u8 keystream[SM4_BLOCK_SIZE * 8];
			u8 iv[SM4_BLOCK_SIZE];
			unsigned int nblocks = min(nbytes >> 4, 8u);
			int i;

			sm4_aesni_avx_crypt8(ctx->rkey_dec, keystream,
						src, nblocks);

			src += ((int)nblocks - 2) * SM4_BLOCK_SIZE;
			dst += (nblocks - 1) * SM4_BLOCK_SIZE;
			memcpy(iv, src + SM4_BLOCK_SIZE, SM4_BLOCK_SIZE);

			for (i = nblocks - 1; i > 0; i--) {
				crypto_xor_cpy(dst, src,
					&keystream[i * SM4_BLOCK_SIZE],
					SM4_BLOCK_SIZE);
				src -= SM4_BLOCK_SIZE;
				dst -= SM4_BLOCK_SIZE;
			}
			crypto_xor_cpy(dst, walk.iv, keystream, SM4_BLOCK_SIZE);
			memcpy(walk.iv, iv, SM4_BLOCK_SIZE);
			dst += nblocks * SM4_BLOCK_SIZE;
			src += (nblocks + 1) * SM4_BLOCK_SIZE;
			nbytes -= nblocks * SM4_BLOCK_SIZE;
		}

		kernel_fpu_end();
		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}
EXPORT_SYMBOL_GPL(sm4_avx_cbc_decrypt);

static int cbc_decrypt(struct skcipher_request *req)
{
	return sm4_avx_cbc_decrypt(req, SM4_CRYPT8_BLOCK_SIZE,
				sm4_aesni_avx_cbc_dec_blk8);
}

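/*
 * CFB encryption is inherently serial (each keystream block is the
 * encryption of the previous ciphertext block), so it uses the generic
 * sm4_crypt_block(). A final partial block is handled by XORing only
 * 'nbytes' bytes of keystream.
 */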
int sm4_cfb_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		u8 keystream[SM4_BLOCK_SIZE];
		const u8 *iv = walk.iv;
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		while (nbytes >= SM4_BLOCK_SIZE) {
			sm4_crypt_block(ctx->rkey_enc, keystream, iv);
			crypto_xor_cpy(dst, src, keystream, SM4_BLOCK_SIZE);
			iv = dst;
			src += SM4_BLOCK_SIZE;
			dst += SM4_BLOCK_SIZE;
			nbytes -= SM4_BLOCK_SIZE;
		}
		if (iv != walk.iv)
			memcpy(walk.iv, iv, SM4_BLOCK_SIZE);

		/* tail */
		if (walk.nbytes == walk.total && nbytes > 0) {
			sm4_crypt_block(ctx->rkey_enc, keystream, walk.iv);
			crypto_xor_cpy(dst, src, keystream, nbytes);
			nbytes = 0;
		}

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}
EXPORT_SYMBOL_GPL(sm4_cfb_encrypt);

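/*
 * CFB decryption parallelizes well: the keystream for block i is the
 * encryption of ciphertext block i-1 (or the IV for the first block).
 * The fallback path assembles the IV plus the first nblocks-1 ciphertext
 * blocks, encrypts them in one sm4_aesni_avx_crypt8() call and XORs the
 * result with the ciphertext; the last ciphertext block becomes the next
 * IV. A trailing partial block is handled separately.
 */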
int sm4_avx_cfb_decrypt(struct skcipher_request *req,
			unsigned int bsize, sm4_crypt_func func)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		kernel_fpu_begin();

		while (nbytes >= bsize) {
			func(ctx->rkey_enc, dst, src, walk.iv);
			dst += bsize;
			src += bsize;
			nbytes -= bsize;
		}

		while (nbytes >= SM4_BLOCK_SIZE) {
			u8 keystream[SM4_BLOCK_SIZE * 8];
			unsigned int nblocks = min(nbytes >> 4, 8u);

			memcpy(keystream, walk.iv, SM4_BLOCK_SIZE);
			if (nblocks > 1)
				memcpy(&keystream[SM4_BLOCK_SIZE], src,
					(nblocks - 1) * SM4_BLOCK_SIZE);
			memcpy(walk.iv, src + (nblocks - 1) * SM4_BLOCK_SIZE,
				SM4_BLOCK_SIZE);

			sm4_aesni_avx_crypt8(ctx->rkey_enc, keystream,
						keystream, nblocks);

			crypto_xor_cpy(dst, src, keystream,
					nblocks * SM4_BLOCK_SIZE);
			dst += nblocks * SM4_BLOCK_SIZE;
			src += nblocks * SM4_BLOCK_SIZE;
			nbytes -= nblocks * SM4_BLOCK_SIZE;
		}

		kernel_fpu_end();

		/* tail */
		if (walk.nbytes == walk.total && nbytes > 0) {
			u8 keystream[SM4_BLOCK_SIZE];

			sm4_crypt_block(ctx->rkey_enc, keystream, walk.iv);
			crypto_xor_cpy(dst, src, keystream, nbytes);
			nbytes = 0;
		}

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}
EXPORT_SYMBOL_GPL(sm4_avx_cfb_decrypt);

static int cfb_decrypt(struct skcipher_request *req)
{
	return sm4_avx_cfb_decrypt(req, SM4_CRYPT8_BLOCK_SIZE,
				sm4_aesni_avx_cfb_dec_blk8);
}

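/*
 * CTR mode: full groups of 'bsize' bytes go through the bulk helper
 * 'func'. For the remainder, counter blocks are generated with
 * crypto_inc(), encrypted up to eight at a time with
 * sm4_aesni_avx_crypt8() and XORed with the input. A final partial
 * block consumes one more counter value and XORs only 'nbytes' bytes
 * of keystream.
 */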
int sm4_avx_ctr_crypt(struct skcipher_request *req,
			unsigned int bsize, sm4_crypt_func func)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		kernel_fpu_begin();

		while (nbytes >= bsize) {
			func(ctx->rkey_enc, dst, src, walk.iv);
			dst += bsize;
			src += bsize;
			nbytes -= bsize;
		}

		while (nbytes >= SM4_BLOCK_SIZE) {
			u8 keystream[SM4_BLOCK_SIZE * 8];
			unsigned int nblocks = min(nbytes >> 4, 8u);
			int i;

			for (i = 0; i < nblocks; i++) {
				memcpy(&keystream[i * SM4_BLOCK_SIZE],
					walk.iv, SM4_BLOCK_SIZE);
				crypto_inc(walk.iv, SM4_BLOCK_SIZE);
			}
			sm4_aesni_avx_crypt8(ctx->rkey_enc, keystream,
					keystream, nblocks);

			crypto_xor_cpy(dst, src, keystream,
					nblocks * SM4_BLOCK_SIZE);
			dst += nblocks * SM4_BLOCK_SIZE;
			src += nblocks * SM4_BLOCK_SIZE;
			nbytes -= nblocks * SM4_BLOCK_SIZE;
		}

		kernel_fpu_end();

		/* tail */
		if (walk.nbytes == walk.total && nbytes > 0) {
			u8 keystream[SM4_BLOCK_SIZE];

			memcpy(keystream, walk.iv, SM4_BLOCK_SIZE);
			crypto_inc(walk.iv, SM4_BLOCK_SIZE);

			sm4_crypt_block(ctx->rkey_enc, keystream, keystream);

			crypto_xor_cpy(dst, src, keystream, nbytes);
			dst += nbytes;
			src += nbytes;
			nbytes = 0;
		}

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}
EXPORT_SYMBOL_GPL(sm4_avx_ctr_crypt);

static int ctr_crypt(struct skcipher_request *req)
{
	return sm4_avx_ctr_crypt(req, SM4_CRYPT8_BLOCK_SIZE,
				sm4_aesni_avx_ctr_enc_blk8);
}

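/*
 * The algorithms below are marked CRYPTO_ALG_INTERNAL ("__"-prefixed
 * names); simd_register_skciphers_compat() wraps them in SIMD helper
 * tfms so they are only invoked when the FPU is usable and are
 * otherwise deferred to cryptd.
 */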
static struct skcipher_alg sm4_aesni_avx_skciphers[] = {
	{
		.base = {
			.cra_name		= "__ecb(sm4)",
			.cra_driver_name	= "__ecb-sm4-aesni-avx",
			.cra_priority		= 400,
			.cra_flags		= CRYPTO_ALG_INTERNAL,
			.cra_blocksize		= SM4_BLOCK_SIZE,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.walksize	= 8 * SM4_BLOCK_SIZE,
		.setkey		= sm4_skcipher_setkey,
		.encrypt	= sm4_avx_ecb_encrypt,
		.decrypt	= sm4_avx_ecb_decrypt,
	}, {
		.base = {
			.cra_name		= "__cbc(sm4)",
			.cra_driver_name	= "__cbc-sm4-aesni-avx",
			.cra_priority		= 400,
			.cra_flags		= CRYPTO_ALG_INTERNAL,
			.cra_blocksize		= SM4_BLOCK_SIZE,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.ivsize		= SM4_BLOCK_SIZE,
		.walksize	= 8 * SM4_BLOCK_SIZE,
		.setkey		= sm4_skcipher_setkey,
		.encrypt	= sm4_cbc_encrypt,
		.decrypt	= cbc_decrypt,
	}, {
		.base = {
			.cra_name		= "__cfb(sm4)",
			.cra_driver_name	= "__cfb-sm4-aesni-avx",
			.cra_priority		= 400,
			.cra_flags		= CRYPTO_ALG_INTERNAL,
			.cra_blocksize		= 1,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.ivsize		= SM4_BLOCK_SIZE,
		.chunksize	= SM4_BLOCK_SIZE,
		.walksize	= 8 * SM4_BLOCK_SIZE,
		.setkey		= sm4_skcipher_setkey,
		.encrypt	= sm4_cfb_encrypt,
		.decrypt	= cfb_decrypt,
	}, {
		.base = {
			.cra_name		= "__ctr(sm4)",
			.cra_driver_name	= "__ctr-sm4-aesni-avx",
			.cra_priority		= 400,
			.cra_flags		= CRYPTO_ALG_INTERNAL,
			.cra_blocksize		= 1,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.ivsize		= SM4_BLOCK_SIZE,
		.chunksize	= SM4_BLOCK_SIZE,
		.walksize	= 8 * SM4_BLOCK_SIZE,
		.setkey		= sm4_skcipher_setkey,
		.encrypt	= ctr_crypt,
		.decrypt	= ctr_crypt,
	}
};

static struct simd_skcipher_alg *
simd_sm4_aesni_avx_skciphers[ARRAY_SIZE(sm4_aesni_avx_skciphers)];

static int __init sm4_init(void)
{
	const char *feature_name;

	if (!boot_cpu_has(X86_FEATURE_AVX) ||
	    !boot_cpu_has(X86_FEATURE_AES) ||
	    !boot_cpu_has(X86_FEATURE_OSXSAVE)) {
		pr_info("AVX or AES-NI instructions are not detected.\n");
		return -ENODEV;
	}

	if (!cpu_has_xfeatures(XFEATURE_MASK_SSE | XFEATURE_MASK_YMM,
				&feature_name)) {
		pr_info("CPU feature '%s' is not supported.\n", feature_name);
		return -ENODEV;
	}

	return simd_register_skciphers_compat(sm4_aesni_avx_skciphers,
					ARRAY_SIZE(sm4_aesni_avx_skciphers),
					simd_sm4_aesni_avx_skciphers);
}

static void __exit sm4_exit(void)
{
	simd_unregister_skciphers(sm4_aesni_avx_skciphers,
					ARRAY_SIZE(sm4_aesni_avx_skciphers),
					simd_sm4_aesni_avx_skciphers);
}

module_init(sm4_init);
module_exit(sm4_exit);

MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Tianjia Zhang <tianjia.zhang@linux.alibaba.com>");
MODULE_DESCRIPTION("SM4 Cipher Algorithm, AES-NI/AVX optimized");
MODULE_ALIAS_CRYPTO("sm4");
MODULE_ALIAS_CRYPTO("sm4-aesni-avx");