sm4_aesni_avx_glue.c (12594B)
1/* SPDX-License-Identifier: GPL-2.0-or-later */ 2/* 3 * SM4 Cipher Algorithm, AES-NI/AVX optimized. 4 * as specified in 5 * https://tools.ietf.org/id/draft-ribose-cfrg-sm4-10.html 6 * 7 * Copyright (c) 2021, Alibaba Group. 8 * Copyright (c) 2021 Tianjia Zhang <tianjia.zhang@linux.alibaba.com> 9 */ 10 11#include <linux/module.h> 12#include <linux/crypto.h> 13#include <linux/kernel.h> 14#include <asm/simd.h> 15#include <crypto/internal/simd.h> 16#include <crypto/internal/skcipher.h> 17#include <crypto/sm4.h> 18#include "sm4-avx.h" 19 20#define SM4_CRYPT8_BLOCK_SIZE (SM4_BLOCK_SIZE * 8) 21 22asmlinkage void sm4_aesni_avx_crypt4(const u32 *rk, u8 *dst, 23 const u8 *src, int nblocks); 24asmlinkage void sm4_aesni_avx_crypt8(const u32 *rk, u8 *dst, 25 const u8 *src, int nblocks); 26asmlinkage void sm4_aesni_avx_ctr_enc_blk8(const u32 *rk, u8 *dst, 27 const u8 *src, u8 *iv); 28asmlinkage void sm4_aesni_avx_cbc_dec_blk8(const u32 *rk, u8 *dst, 29 const u8 *src, u8 *iv); 30asmlinkage void sm4_aesni_avx_cfb_dec_blk8(const u32 *rk, u8 *dst, 31 const u8 *src, u8 *iv); 32 33static int sm4_skcipher_setkey(struct crypto_skcipher *tfm, const u8 *key, 34 unsigned int key_len) 35{ 36 struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm); 37 38 return sm4_expandkey(ctx, key, key_len); 39} 40 41static int ecb_do_crypt(struct skcipher_request *req, const u32 *rkey) 42{ 43 struct skcipher_walk walk; 44 unsigned int nbytes; 45 int err; 46 47 err = skcipher_walk_virt(&walk, req, false); 48 49 while ((nbytes = walk.nbytes) > 0) { 50 const u8 *src = walk.src.virt.addr; 51 u8 *dst = walk.dst.virt.addr; 52 53 kernel_fpu_begin(); 54 while (nbytes >= SM4_CRYPT8_BLOCK_SIZE) { 55 sm4_aesni_avx_crypt8(rkey, dst, src, 8); 56 dst += SM4_CRYPT8_BLOCK_SIZE; 57 src += SM4_CRYPT8_BLOCK_SIZE; 58 nbytes -= SM4_CRYPT8_BLOCK_SIZE; 59 } 60 while (nbytes >= SM4_BLOCK_SIZE) { 61 unsigned int nblocks = min(nbytes >> 4, 4u); 62 sm4_aesni_avx_crypt4(rkey, dst, src, nblocks); 63 dst += nblocks * SM4_BLOCK_SIZE; 64 src += nblocks * SM4_BLOCK_SIZE; 65 nbytes -= nblocks * SM4_BLOCK_SIZE; 66 } 67 kernel_fpu_end(); 68 69 err = skcipher_walk_done(&walk, nbytes); 70 } 71 72 return err; 73} 74 75int sm4_avx_ecb_encrypt(struct skcipher_request *req) 76{ 77 struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req); 78 struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm); 79 80 return ecb_do_crypt(req, ctx->rkey_enc); 81} 82EXPORT_SYMBOL_GPL(sm4_avx_ecb_encrypt); 83 84int sm4_avx_ecb_decrypt(struct skcipher_request *req) 85{ 86 struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req); 87 struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm); 88 89 return ecb_do_crypt(req, ctx->rkey_dec); 90} 91EXPORT_SYMBOL_GPL(sm4_avx_ecb_decrypt); 92 93int sm4_cbc_encrypt(struct skcipher_request *req) 94{ 95 struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req); 96 struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm); 97 struct skcipher_walk walk; 98 unsigned int nbytes; 99 int err; 100 101 err = skcipher_walk_virt(&walk, req, false); 102 103 while ((nbytes = walk.nbytes) > 0) { 104 const u8 *iv = walk.iv; 105 const u8 *src = walk.src.virt.addr; 106 u8 *dst = walk.dst.virt.addr; 107 108 while (nbytes >= SM4_BLOCK_SIZE) { 109 crypto_xor_cpy(dst, src, iv, SM4_BLOCK_SIZE); 110 sm4_crypt_block(ctx->rkey_enc, dst, dst); 111 iv = dst; 112 src += SM4_BLOCK_SIZE; 113 dst += SM4_BLOCK_SIZE; 114 nbytes -= SM4_BLOCK_SIZE; 115 } 116 if (iv != walk.iv) 117 memcpy(walk.iv, iv, SM4_BLOCK_SIZE); 118 119 err = skcipher_walk_done(&walk, nbytes); 120 } 121 122 return err; 123} 124EXPORT_SYMBOL_GPL(sm4_cbc_encrypt); 125 126int sm4_avx_cbc_decrypt(struct skcipher_request *req, 127 unsigned int bsize, sm4_crypt_func func) 128{ 129 struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req); 130 struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm); 131 struct skcipher_walk walk; 132 unsigned int nbytes; 133 int err; 134 135 err = skcipher_walk_virt(&walk, req, false); 136 137 while ((nbytes = walk.nbytes) > 0) { 138 const u8 *src = walk.src.virt.addr; 139 u8 *dst = walk.dst.virt.addr; 140 141 kernel_fpu_begin(); 142 143 while (nbytes >= bsize) { 144 func(ctx->rkey_dec, dst, src, walk.iv); 145 dst += bsize; 146 src += bsize; 147 nbytes -= bsize; 148 } 149 150 while (nbytes >= SM4_BLOCK_SIZE) { 151 u8 keystream[SM4_BLOCK_SIZE * 8]; 152 u8 iv[SM4_BLOCK_SIZE]; 153 unsigned int nblocks = min(nbytes >> 4, 8u); 154 int i; 155 156 sm4_aesni_avx_crypt8(ctx->rkey_dec, keystream, 157 src, nblocks); 158 159 src += ((int)nblocks - 2) * SM4_BLOCK_SIZE; 160 dst += (nblocks - 1) * SM4_BLOCK_SIZE; 161 memcpy(iv, src + SM4_BLOCK_SIZE, SM4_BLOCK_SIZE); 162 163 for (i = nblocks - 1; i > 0; i--) { 164 crypto_xor_cpy(dst, src, 165 &keystream[i * SM4_BLOCK_SIZE], 166 SM4_BLOCK_SIZE); 167 src -= SM4_BLOCK_SIZE; 168 dst -= SM4_BLOCK_SIZE; 169 } 170 crypto_xor_cpy(dst, walk.iv, keystream, SM4_BLOCK_SIZE); 171 memcpy(walk.iv, iv, SM4_BLOCK_SIZE); 172 dst += nblocks * SM4_BLOCK_SIZE; 173 src += (nblocks + 1) * SM4_BLOCK_SIZE; 174 nbytes -= nblocks * SM4_BLOCK_SIZE; 175 } 176 177 kernel_fpu_end(); 178 err = skcipher_walk_done(&walk, nbytes); 179 } 180 181 return err; 182} 183EXPORT_SYMBOL_GPL(sm4_avx_cbc_decrypt); 184 185static int cbc_decrypt(struct skcipher_request *req) 186{ 187 return sm4_avx_cbc_decrypt(req, SM4_CRYPT8_BLOCK_SIZE, 188 sm4_aesni_avx_cbc_dec_blk8); 189} 190 191int sm4_cfb_encrypt(struct skcipher_request *req) 192{ 193 struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req); 194 struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm); 195 struct skcipher_walk walk; 196 unsigned int nbytes; 197 int err; 198 199 err = skcipher_walk_virt(&walk, req, false); 200 201 while ((nbytes = walk.nbytes) > 0) { 202 u8 keystream[SM4_BLOCK_SIZE]; 203 const u8 *iv = walk.iv; 204 const u8 *src = walk.src.virt.addr; 205 u8 *dst = walk.dst.virt.addr; 206 207 while (nbytes >= SM4_BLOCK_SIZE) { 208 sm4_crypt_block(ctx->rkey_enc, keystream, iv); 209 crypto_xor_cpy(dst, src, keystream, SM4_BLOCK_SIZE); 210 iv = dst; 211 src += SM4_BLOCK_SIZE; 212 dst += SM4_BLOCK_SIZE; 213 nbytes -= SM4_BLOCK_SIZE; 214 } 215 if (iv != walk.iv) 216 memcpy(walk.iv, iv, SM4_BLOCK_SIZE); 217 218 /* tail */ 219 if (walk.nbytes == walk.total && nbytes > 0) { 220 sm4_crypt_block(ctx->rkey_enc, keystream, walk.iv); 221 crypto_xor_cpy(dst, src, keystream, nbytes); 222 nbytes = 0; 223 } 224 225 err = skcipher_walk_done(&walk, nbytes); 226 } 227 228 return err; 229} 230EXPORT_SYMBOL_GPL(sm4_cfb_encrypt); 231 232int sm4_avx_cfb_decrypt(struct skcipher_request *req, 233 unsigned int bsize, sm4_crypt_func func) 234{ 235 struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req); 236 struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm); 237 struct skcipher_walk walk; 238 unsigned int nbytes; 239 int err; 240 241 err = skcipher_walk_virt(&walk, req, false); 242 243 while ((nbytes = walk.nbytes) > 0) { 244 const u8 *src = walk.src.virt.addr; 245 u8 *dst = walk.dst.virt.addr; 246 247 kernel_fpu_begin(); 248 249 while (nbytes >= bsize) { 250 func(ctx->rkey_enc, dst, src, walk.iv); 251 dst += bsize; 252 src += bsize; 253 nbytes -= bsize; 254 } 255 256 while (nbytes >= SM4_BLOCK_SIZE) { 257 u8 keystream[SM4_BLOCK_SIZE * 8]; 258 unsigned int nblocks = min(nbytes >> 4, 8u); 259 260 memcpy(keystream, walk.iv, SM4_BLOCK_SIZE); 261 if (nblocks > 1) 262 memcpy(&keystream[SM4_BLOCK_SIZE], src, 263 (nblocks - 1) * SM4_BLOCK_SIZE); 264 memcpy(walk.iv, src + (nblocks - 1) * SM4_BLOCK_SIZE, 265 SM4_BLOCK_SIZE); 266 267 sm4_aesni_avx_crypt8(ctx->rkey_enc, keystream, 268 keystream, nblocks); 269 270 crypto_xor_cpy(dst, src, keystream, 271 nblocks * SM4_BLOCK_SIZE); 272 dst += nblocks * SM4_BLOCK_SIZE; 273 src += nblocks * SM4_BLOCK_SIZE; 274 nbytes -= nblocks * SM4_BLOCK_SIZE; 275 } 276 277 kernel_fpu_end(); 278 279 /* tail */ 280 if (walk.nbytes == walk.total && nbytes > 0) { 281 u8 keystream[SM4_BLOCK_SIZE]; 282 283 sm4_crypt_block(ctx->rkey_enc, keystream, walk.iv); 284 crypto_xor_cpy(dst, src, keystream, nbytes); 285 nbytes = 0; 286 } 287 288 err = skcipher_walk_done(&walk, nbytes); 289 } 290 291 return err; 292} 293EXPORT_SYMBOL_GPL(sm4_avx_cfb_decrypt); 294 295static int cfb_decrypt(struct skcipher_request *req) 296{ 297 return sm4_avx_cfb_decrypt(req, SM4_CRYPT8_BLOCK_SIZE, 298 sm4_aesni_avx_cfb_dec_blk8); 299} 300 301int sm4_avx_ctr_crypt(struct skcipher_request *req, 302 unsigned int bsize, sm4_crypt_func func) 303{ 304 struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req); 305 struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm); 306 struct skcipher_walk walk; 307 unsigned int nbytes; 308 int err; 309 310 err = skcipher_walk_virt(&walk, req, false); 311 312 while ((nbytes = walk.nbytes) > 0) { 313 const u8 *src = walk.src.virt.addr; 314 u8 *dst = walk.dst.virt.addr; 315 316 kernel_fpu_begin(); 317 318 while (nbytes >= bsize) { 319 func(ctx->rkey_enc, dst, src, walk.iv); 320 dst += bsize; 321 src += bsize; 322 nbytes -= bsize; 323 } 324 325 while (nbytes >= SM4_BLOCK_SIZE) { 326 u8 keystream[SM4_BLOCK_SIZE * 8]; 327 unsigned int nblocks = min(nbytes >> 4, 8u); 328 int i; 329 330 for (i = 0; i < nblocks; i++) { 331 memcpy(&keystream[i * SM4_BLOCK_SIZE], 332 walk.iv, SM4_BLOCK_SIZE); 333 crypto_inc(walk.iv, SM4_BLOCK_SIZE); 334 } 335 sm4_aesni_avx_crypt8(ctx->rkey_enc, keystream, 336 keystream, nblocks); 337 338 crypto_xor_cpy(dst, src, keystream, 339 nblocks * SM4_BLOCK_SIZE); 340 dst += nblocks * SM4_BLOCK_SIZE; 341 src += nblocks * SM4_BLOCK_SIZE; 342 nbytes -= nblocks * SM4_BLOCK_SIZE; 343 } 344 345 kernel_fpu_end(); 346 347 /* tail */ 348 if (walk.nbytes == walk.total && nbytes > 0) { 349 u8 keystream[SM4_BLOCK_SIZE]; 350 351 memcpy(keystream, walk.iv, SM4_BLOCK_SIZE); 352 crypto_inc(walk.iv, SM4_BLOCK_SIZE); 353 354 sm4_crypt_block(ctx->rkey_enc, keystream, keystream); 355 356 crypto_xor_cpy(dst, src, keystream, nbytes); 357 dst += nbytes; 358 src += nbytes; 359 nbytes = 0; 360 } 361 362 err = skcipher_walk_done(&walk, nbytes); 363 } 364 365 return err; 366} 367EXPORT_SYMBOL_GPL(sm4_avx_ctr_crypt); 368 369static int ctr_crypt(struct skcipher_request *req) 370{ 371 return sm4_avx_ctr_crypt(req, SM4_CRYPT8_BLOCK_SIZE, 372 sm4_aesni_avx_ctr_enc_blk8); 373} 374 375static struct skcipher_alg sm4_aesni_avx_skciphers[] = { 376 { 377 .base = { 378 .cra_name = "__ecb(sm4)", 379 .cra_driver_name = "__ecb-sm4-aesni-avx", 380 .cra_priority = 400, 381 .cra_flags = CRYPTO_ALG_INTERNAL, 382 .cra_blocksize = SM4_BLOCK_SIZE, 383 .cra_ctxsize = sizeof(struct sm4_ctx), 384 .cra_module = THIS_MODULE, 385 }, 386 .min_keysize = SM4_KEY_SIZE, 387 .max_keysize = SM4_KEY_SIZE, 388 .walksize = 8 * SM4_BLOCK_SIZE, 389 .setkey = sm4_skcipher_setkey, 390 .encrypt = sm4_avx_ecb_encrypt, 391 .decrypt = sm4_avx_ecb_decrypt, 392 }, { 393 .base = { 394 .cra_name = "__cbc(sm4)", 395 .cra_driver_name = "__cbc-sm4-aesni-avx", 396 .cra_priority = 400, 397 .cra_flags = CRYPTO_ALG_INTERNAL, 398 .cra_blocksize = SM4_BLOCK_SIZE, 399 .cra_ctxsize = sizeof(struct sm4_ctx), 400 .cra_module = THIS_MODULE, 401 }, 402 .min_keysize = SM4_KEY_SIZE, 403 .max_keysize = SM4_KEY_SIZE, 404 .ivsize = SM4_BLOCK_SIZE, 405 .walksize = 8 * SM4_BLOCK_SIZE, 406 .setkey = sm4_skcipher_setkey, 407 .encrypt = sm4_cbc_encrypt, 408 .decrypt = cbc_decrypt, 409 }, { 410 .base = { 411 .cra_name = "__cfb(sm4)", 412 .cra_driver_name = "__cfb-sm4-aesni-avx", 413 .cra_priority = 400, 414 .cra_flags = CRYPTO_ALG_INTERNAL, 415 .cra_blocksize = 1, 416 .cra_ctxsize = sizeof(struct sm4_ctx), 417 .cra_module = THIS_MODULE, 418 }, 419 .min_keysize = SM4_KEY_SIZE, 420 .max_keysize = SM4_KEY_SIZE, 421 .ivsize = SM4_BLOCK_SIZE, 422 .chunksize = SM4_BLOCK_SIZE, 423 .walksize = 8 * SM4_BLOCK_SIZE, 424 .setkey = sm4_skcipher_setkey, 425 .encrypt = sm4_cfb_encrypt, 426 .decrypt = cfb_decrypt, 427 }, { 428 .base = { 429 .cra_name = "__ctr(sm4)", 430 .cra_driver_name = "__ctr-sm4-aesni-avx", 431 .cra_priority = 400, 432 .cra_flags = CRYPTO_ALG_INTERNAL, 433 .cra_blocksize = 1, 434 .cra_ctxsize = sizeof(struct sm4_ctx), 435 .cra_module = THIS_MODULE, 436 }, 437 .min_keysize = SM4_KEY_SIZE, 438 .max_keysize = SM4_KEY_SIZE, 439 .ivsize = SM4_BLOCK_SIZE, 440 .chunksize = SM4_BLOCK_SIZE, 441 .walksize = 8 * SM4_BLOCK_SIZE, 442 .setkey = sm4_skcipher_setkey, 443 .encrypt = ctr_crypt, 444 .decrypt = ctr_crypt, 445 } 446}; 447 448static struct simd_skcipher_alg * 449simd_sm4_aesni_avx_skciphers[ARRAY_SIZE(sm4_aesni_avx_skciphers)]; 450 451static int __init sm4_init(void) 452{ 453 const char *feature_name; 454 455 if (!boot_cpu_has(X86_FEATURE_AVX) || 456 !boot_cpu_has(X86_FEATURE_AES) || 457 !boot_cpu_has(X86_FEATURE_OSXSAVE)) { 458 pr_info("AVX or AES-NI instructions are not detected.\n"); 459 return -ENODEV; 460 } 461 462 if (!cpu_has_xfeatures(XFEATURE_MASK_SSE | XFEATURE_MASK_YMM, 463 &feature_name)) { 464 pr_info("CPU feature '%s' is not supported.\n", feature_name); 465 return -ENODEV; 466 } 467 468 return simd_register_skciphers_compat(sm4_aesni_avx_skciphers, 469 ARRAY_SIZE(sm4_aesni_avx_skciphers), 470 simd_sm4_aesni_avx_skciphers); 471} 472 473static void __exit sm4_exit(void) 474{ 475 simd_unregister_skciphers(sm4_aesni_avx_skciphers, 476 ARRAY_SIZE(sm4_aesni_avx_skciphers), 477 simd_sm4_aesni_avx_skciphers); 478} 479 480module_init(sm4_init); 481module_exit(sm4_exit); 482 483MODULE_LICENSE("GPL v2"); 484MODULE_AUTHOR("Tianjia Zhang <tianjia.zhang@linux.alibaba.com>"); 485MODULE_DESCRIPTION("SM4 Cipher Algorithm, AES-NI/AVX optimized"); 486MODULE_ALIAS_CRYPTO("sm4"); 487MODULE_ALIAS_CRYPTO("sm4-aesni-avx");