cachepc-linux

Fork of AMDESE/linux with modifications for the CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux
Log | Files | Refs | README | LICENSE | sfeed.txt

zstd.c (5090B)


      1// SPDX-License-Identifier: GPL-2.0-only
      2/*
      3 * Cryptographic API.
      4 *
      5 * Copyright (c) 2017-present, Facebook, Inc.
      6 */
      7#include <linux/crypto.h>
      8#include <linux/init.h>
      9#include <linux/interrupt.h>
     10#include <linux/mm.h>
     11#include <linux/module.h>
     12#include <linux/net.h>
     13#include <linux/vmalloc.h>
     14#include <linux/zstd.h>
     15#include <crypto/internal/scompress.h>
     16
     17
/* Default zstd level; zstd_params() passes it to zstd_get_params(). */
#define ZSTD_DEF_LEVEL	3
     19
/*
 * Per-transform zstd state: one compression and one decompression context,
 * each backed by its own vzalloc()ed workspace.
 */
struct zstd_ctx {
	zstd_cctx *cctx;	/* from zstd_init_cctx() on cwksp */
	zstd_dctx *dctx;	/* from zstd_init_dctx() on dwksp */
	void *cwksp;		/* compression workspace, released with vfree() */
	void *dwksp;		/* decompression workspace, released with vfree() */
};
     26
     27static zstd_parameters zstd_params(void)
     28{
     29	return zstd_get_params(ZSTD_DEF_LEVEL, 0);
     30}
     31
     32static int zstd_comp_init(struct zstd_ctx *ctx)
     33{
     34	int ret = 0;
     35	const zstd_parameters params = zstd_params();
     36	const size_t wksp_size = zstd_cctx_workspace_bound(&params.cParams);
     37
     38	ctx->cwksp = vzalloc(wksp_size);
     39	if (!ctx->cwksp) {
     40		ret = -ENOMEM;
     41		goto out;
     42	}
     43
     44	ctx->cctx = zstd_init_cctx(ctx->cwksp, wksp_size);
     45	if (!ctx->cctx) {
     46		ret = -EINVAL;
     47		goto out_free;
     48	}
     49out:
     50	return ret;
     51out_free:
     52	vfree(ctx->cwksp);
     53	goto out;
     54}
     55
     56static int zstd_decomp_init(struct zstd_ctx *ctx)
     57{
     58	int ret = 0;
     59	const size_t wksp_size = zstd_dctx_workspace_bound();
     60
     61	ctx->dwksp = vzalloc(wksp_size);
     62	if (!ctx->dwksp) {
     63		ret = -ENOMEM;
     64		goto out;
     65	}
     66
     67	ctx->dctx = zstd_init_dctx(ctx->dwksp, wksp_size);
     68	if (!ctx->dctx) {
     69		ret = -EINVAL;
     70		goto out_free;
     71	}
     72out:
     73	return ret;
     74out_free:
     75	vfree(ctx->dwksp);
     76	goto out;
     77}
     78
     79static void zstd_comp_exit(struct zstd_ctx *ctx)
     80{
     81	vfree(ctx->cwksp);
     82	ctx->cwksp = NULL;
     83	ctx->cctx = NULL;
     84}
     85
     86static void zstd_decomp_exit(struct zstd_ctx *ctx)
     87{
     88	vfree(ctx->dwksp);
     89	ctx->dwksp = NULL;
     90	ctx->dctx = NULL;
     91}
     92
     93static int __zstd_init(void *ctx)
     94{
     95	int ret;
     96
     97	ret = zstd_comp_init(ctx);
     98	if (ret)
     99		return ret;
    100	ret = zstd_decomp_init(ctx);
    101	if (ret)
    102		zstd_comp_exit(ctx);
    103	return ret;
    104}
    105
    106static void *zstd_alloc_ctx(struct crypto_scomp *tfm)
    107{
    108	int ret;
    109	struct zstd_ctx *ctx;
    110
    111	ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
    112	if (!ctx)
    113		return ERR_PTR(-ENOMEM);
    114
    115	ret = __zstd_init(ctx);
    116	if (ret) {
    117		kfree(ctx);
    118		return ERR_PTR(ret);
    119	}
    120
    121	return ctx;
    122}
    123
    124static int zstd_init(struct crypto_tfm *tfm)
    125{
    126	struct zstd_ctx *ctx = crypto_tfm_ctx(tfm);
    127
    128	return __zstd_init(ctx);
    129}
    130
    131static void __zstd_exit(void *ctx)
    132{
    133	zstd_comp_exit(ctx);
    134	zstd_decomp_exit(ctx);
    135}
    136
    137static void zstd_free_ctx(struct crypto_scomp *tfm, void *ctx)
    138{
    139	__zstd_exit(ctx);
    140	kfree_sensitive(ctx);
    141}
    142
    143static void zstd_exit(struct crypto_tfm *tfm)
    144{
    145	struct zstd_ctx *ctx = crypto_tfm_ctx(tfm);
    146
    147	__zstd_exit(ctx);
    148}
    149
    150static int __zstd_compress(const u8 *src, unsigned int slen,
    151			   u8 *dst, unsigned int *dlen, void *ctx)
    152{
    153	size_t out_len;
    154	struct zstd_ctx *zctx = ctx;
    155	const zstd_parameters params = zstd_params();
    156
    157	out_len = zstd_compress_cctx(zctx->cctx, dst, *dlen, src, slen, &params);
    158	if (zstd_is_error(out_len))
    159		return -EINVAL;
    160	*dlen = out_len;
    161	return 0;
    162}
    163
    164static int zstd_compress(struct crypto_tfm *tfm, const u8 *src,
    165			 unsigned int slen, u8 *dst, unsigned int *dlen)
    166{
    167	struct zstd_ctx *ctx = crypto_tfm_ctx(tfm);
    168
    169	return __zstd_compress(src, slen, dst, dlen, ctx);
    170}
    171
    172static int zstd_scompress(struct crypto_scomp *tfm, const u8 *src,
    173			  unsigned int slen, u8 *dst, unsigned int *dlen,
    174			  void *ctx)
    175{
    176	return __zstd_compress(src, slen, dst, dlen, ctx);
    177}
    178
    179static int __zstd_decompress(const u8 *src, unsigned int slen,
    180			     u8 *dst, unsigned int *dlen, void *ctx)
    181{
    182	size_t out_len;
    183	struct zstd_ctx *zctx = ctx;
    184
    185	out_len = zstd_decompress_dctx(zctx->dctx, dst, *dlen, src, slen);
    186	if (zstd_is_error(out_len))
    187		return -EINVAL;
    188	*dlen = out_len;
    189	return 0;
    190}
    191
    192static int zstd_decompress(struct crypto_tfm *tfm, const u8 *src,
    193			   unsigned int slen, u8 *dst, unsigned int *dlen)
    194{
    195	struct zstd_ctx *ctx = crypto_tfm_ctx(tfm);
    196
    197	return __zstd_decompress(src, slen, dst, dlen, ctx);
    198}
    199
    200static int zstd_sdecompress(struct crypto_scomp *tfm, const u8 *src,
    201			    unsigned int slen, u8 *dst, unsigned int *dlen,
    202			    void *ctx)
    203{
    204	return __zstd_decompress(src, slen, dst, dlen, ctx);
    205}
    206
    207static struct crypto_alg alg = {
    208	.cra_name		= "zstd",
    209	.cra_driver_name	= "zstd-generic",
    210	.cra_flags		= CRYPTO_ALG_TYPE_COMPRESS,
    211	.cra_ctxsize		= sizeof(struct zstd_ctx),
    212	.cra_module		= THIS_MODULE,
    213	.cra_init		= zstd_init,
    214	.cra_exit		= zstd_exit,
    215	.cra_u			= { .compress = {
    216	.coa_compress		= zstd_compress,
    217	.coa_decompress		= zstd_decompress } }
    218};
    219
    220static struct scomp_alg scomp = {
    221	.alloc_ctx		= zstd_alloc_ctx,
    222	.free_ctx		= zstd_free_ctx,
    223	.compress		= zstd_scompress,
    224	.decompress		= zstd_sdecompress,
    225	.base			= {
    226		.cra_name	= "zstd",
    227		.cra_driver_name = "zstd-scomp",
    228		.cra_module	 = THIS_MODULE,
    229	}
    230};
    231
    232static int __init zstd_mod_init(void)
    233{
    234	int ret;
    235
    236	ret = crypto_register_alg(&alg);
    237	if (ret)
    238		return ret;
    239
    240	ret = crypto_register_scomp(&scomp);
    241	if (ret)
    242		crypto_unregister_alg(&alg);
    243
    244	return ret;
    245}
    246
    247static void __exit zstd_mod_fini(void)
    248{
    249	crypto_unregister_alg(&alg);
    250	crypto_unregister_scomp(&scomp);
    251}
    252
/*
 * Hook the init/exit functions into the kernel: zstd_mod_init() is installed
 * via subsys_initcall() and zstd_mod_fini() via module_exit().
 */
subsys_initcall(zstd_mod_init);
module_exit(zstd_mod_fini);

MODULE_LICENSE("GPL");
MODULE_DESCRIPTION("Zstd Compression Algorithm");
/* NOTE(review): presumably lets the module be found by the name "zstd" - confirm. */
MODULE_ALIAS_CRYPTO("zstd");