cachepc-linux

Fork of AMDESE/linux with modifications for CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux
Log | Files | Refs | README | LICENSE | sfeed.txt

decompress_unzstd.c (10484B)


      1// SPDX-License-Identifier: GPL-2.0
      2
      3/*
      4 * Important notes about in-place decompression
      5 *
      6 * At least on x86, the kernel is decompressed in place: the compressed data
      7 * is placed at the end of the output buffer, and the decompressor overwrites
      8 * most of the compressed data. There must be enough safety margin to
      9 * guarantee that the write position is always behind the read position.
     10 *
     11 * The safety margin for ZSTD with a 128 KB block size is calculated below.
     12 * Note that the margin with ZSTD is bigger than with GZIP or XZ!
     13 *
     14 * The worst case for in-place decompression is that the beginning of
     15 * the file is compressed extremely well, and the rest of the file is
     16 * uncompressible. Thus, we must look for worst-case expansion when the
     17 * compressor is encoding uncompressible data.
     18 *
     19 * The structure of the .zst file in case of a compressed kernel is as follows.
     20 * Maximum sizes (in bytes) of the fields are in parentheses.
     21 *
     22 *    Frame Header: (18)
     23 *    Blocks: (N)
     24 *    Checksum: (4)
     25 *
     26 * The frame header and checksum overhead is at most 22 bytes.
     27 *
     28 * ZSTD stores the data in blocks. Each block has a header whose size is
     29 * 3 bytes. After the block header, there is up to 128 KB of payload.
     30 * The maximum uncompressed size of the payload is 128 KB. The minimum
     31 * uncompressed size of the payload is never less than the payload size
     32 * (excluding the block header).
     33 *
     34 * The assumption, that the uncompressed size of the payload is never
     35 * smaller than the payload itself, is valid only when talking about
     36 * the payload as a whole. It is possible that the payload has parts where
     37 * the decompressor consumes more input than it produces output. Calculating
     38 * the worst case for this would be tricky. Instead of trying to do that,
     39 * let's simply make sure that the decompressor never overwrites any bytes
     40 * of the payload which it is currently reading.
     41 *
     42 * Now we have enough information to calculate the safety margin. We need
     43 *   - 22 bytes for the .zst file format headers;
     44 *   - 3 bytes for every 128 KiB of uncompressed size (one block header per
     45 *     block); and
     46 *   - 128 KiB (biggest possible zstd block size) to make sure that the
     47 *     decompressor never overwrites anything from the block it is currently
     48 *     reading.
     49 *
     50 * We get the following formula:
     51 *
     52 *    safety_margin = 22 + uncompressed_size * 3 / 131072 + 131072
     53 *                 <= 22 + (uncompressed_size >> 15) + 131072
     54 */
     55
     56/*
     57 * Preboot environments #include "path/to/decompress_unzstd.c".
     58 * All of the source files we depend on must be #included.
     59 * zstd's only source dependency is xxhash, which has no source
     60 * dependencies.
     61 *
     62 * When UNZSTD_PREBOOT is defined we declare __decompress(), which is
     63 * used for kernel decompression, instead of unzstd().
     64 *
     65 * Define __DISABLE_EXPORTS in preboot environments to prevent symbols
     66 * from xxhash and zstd from being exported by the EXPORT_SYMBOL macro.
     67 */
     68#ifdef STATIC
     69# define UNZSTD_PREBOOT
     70# include "xxhash.c"
     71# include "zstd/decompress_sources.h"
     72#endif
     73
     74#include <linux/decompress/mm.h>
     75#include <linux/kernel.h>
     76#include <linux/zstd.h>
     77
     78/* 128MB is the maximum window size supported by zstd. */
     79#define ZSTD_WINDOWSIZE_MAX	(1 << ZSTD_WINDOWLOG_MAX)
     80/*
     81 * Size of the input and output buffers in multi-call mode.
     82 * Pick a larger size because it isn't used during kernel decompression,
     83 * since that is single pass, and we have to allocate a large buffer for
     84 * zstd's window anyway. The larger size speeds up initramfs decompression.
     85 */
     86#define ZSTD_IOBUF_SIZE		(1 << 17)
     87
     88static int INIT handle_zstd_error(size_t ret, void (*error)(char *x))
     89{
     90	const zstd_error_code err = zstd_get_error_code(ret);
     91
     92	if (!zstd_is_error(ret))
     93		return 0;
     94
     95	/*
     96	 * zstd_get_error_name() cannot be used because error takes a char *
     97	 * not a const char *
     98	 */
     99	switch (err) {
    100	case ZSTD_error_memory_allocation:
    101		error("ZSTD decompressor ran out of memory");
    102		break;
    103	case ZSTD_error_prefix_unknown:
    104		error("Input is not in the ZSTD format (wrong magic bytes)");
    105		break;
    106	case ZSTD_error_dstSize_tooSmall:
    107	case ZSTD_error_corruption_detected:
    108	case ZSTD_error_checksum_wrong:
    109		error("ZSTD-compressed data is corrupt");
    110		break;
    111	default:
    112		error("ZSTD-compressed data is probably corrupt");
    113		break;
    114	}
    115	return -1;
    116}
    117
    118/*
    119 * Handle the case where we have the entire input and output in one segment.
    120 * We can allocate less memory (no circular buffer for the sliding window),
    121 * and avoid some memcpy() calls.
    122 */
    123static int INIT decompress_single(const u8 *in_buf, long in_len, u8 *out_buf,
    124				  long out_len, long *in_pos,
    125				  void (*error)(char *x))
    126{
    127	const size_t wksp_size = zstd_dctx_workspace_bound();
    128	void *wksp = large_malloc(wksp_size);
    129	zstd_dctx *dctx = zstd_init_dctx(wksp, wksp_size);
    130	int err;
    131	size_t ret;
    132
    133	if (dctx == NULL) {
    134		error("Out of memory while allocating zstd_dctx");
    135		err = -1;
    136		goto out;
    137	}
    138	/*
    139	 * Find out how large the frame actually is, there may be junk at
    140	 * the end of the frame that zstd_decompress_dctx() can't handle.
    141	 */
    142	ret = zstd_find_frame_compressed_size(in_buf, in_len);
    143	err = handle_zstd_error(ret, error);
    144	if (err)
    145		goto out;
    146	in_len = (long)ret;
    147
    148	ret = zstd_decompress_dctx(dctx, out_buf, out_len, in_buf, in_len);
    149	err = handle_zstd_error(ret, error);
    150	if (err)
    151		goto out;
    152
    153	if (in_pos != NULL)
    154		*in_pos = in_len;
    155
    156	err = 0;
    157out:
    158	if (wksp != NULL)
    159		large_free(wksp);
    160	return err;
    161}
    162
    163static int INIT __unzstd(unsigned char *in_buf, long in_len,
    164			 long (*fill)(void*, unsigned long),
    165			 long (*flush)(void*, unsigned long),
    166			 unsigned char *out_buf, long out_len,
    167			 long *in_pos,
    168			 void (*error)(char *x))
    169{
    170	zstd_in_buffer in;
    171	zstd_out_buffer out;
    172	zstd_frame_header header;
    173	void *in_allocated = NULL;
    174	void *out_allocated = NULL;
    175	void *wksp = NULL;
    176	size_t wksp_size;
    177	zstd_dstream *dstream;
    178	int err;
    179	size_t ret;
    180
    181	/*
    182	 * ZSTD decompression code won't be happy if the buffer size is so big
    183	 * that its end address overflows. When the size is not provided, make
    184	 * it as big as possible without having the end address overflow.
    185	 */
    186	if (out_len == 0)
    187		out_len = UINTPTR_MAX - (uintptr_t)out_buf;
    188
    189	if (fill == NULL && flush == NULL)
    190		/*
    191		 * We can decompress faster and with less memory when we have a
    192		 * single chunk.
    193		 */
    194		return decompress_single(in_buf, in_len, out_buf, out_len,
    195					 in_pos, error);
    196
    197	/*
    198	 * If in_buf is not provided, we must be using fill(), so allocate
    199	 * a large enough buffer. If it is provided, it must be at least
    200	 * ZSTD_IOBUF_SIZE large.
    201	 */
    202	if (in_buf == NULL) {
    203		in_allocated = large_malloc(ZSTD_IOBUF_SIZE);
    204		if (in_allocated == NULL) {
    205			error("Out of memory while allocating input buffer");
    206			err = -1;
    207			goto out;
    208		}
    209		in_buf = in_allocated;
    210		in_len = 0;
    211	}
    212	/* Read the first chunk, since we need to decode the frame header. */
    213	if (fill != NULL)
    214		in_len = fill(in_buf, ZSTD_IOBUF_SIZE);
    215	if (in_len < 0) {
    216		error("ZSTD-compressed data is truncated");
    217		err = -1;
    218		goto out;
    219	}
    220	/* Set the first non-empty input buffer. */
    221	in.src = in_buf;
    222	in.pos = 0;
    223	in.size = in_len;
    224	/* Allocate the output buffer if we are using flush(). */
    225	if (flush != NULL) {
    226		out_allocated = large_malloc(ZSTD_IOBUF_SIZE);
    227		if (out_allocated == NULL) {
    228			error("Out of memory while allocating output buffer");
    229			err = -1;
    230			goto out;
    231		}
    232		out_buf = out_allocated;
    233		out_len = ZSTD_IOBUF_SIZE;
    234	}
    235	/* Set the output buffer. */
    236	out.dst = out_buf;
    237	out.pos = 0;
    238	out.size = out_len;
    239
    240	/*
    241	 * We need to know the window size to allocate the zstd_dstream.
    242	 * Since we are streaming, we need to allocate a buffer for the sliding
    243	 * window. The window size varies from 1 KB to ZSTD_WINDOWSIZE_MAX
    244	 * (8 MB), so it is important to use the actual value so as not to
    245	 * waste memory when it is smaller.
    246	 */
    247	ret = zstd_get_frame_header(&header, in.src, in.size);
    248	err = handle_zstd_error(ret, error);
    249	if (err)
    250		goto out;
    251	if (ret != 0) {
    252		error("ZSTD-compressed data has an incomplete frame header");
    253		err = -1;
    254		goto out;
    255	}
    256	if (header.windowSize > ZSTD_WINDOWSIZE_MAX) {
    257		error("ZSTD-compressed data has too large a window size");
    258		err = -1;
    259		goto out;
    260	}
    261
    262	/*
    263	 * Allocate the zstd_dstream now that we know how much memory is
    264	 * required.
    265	 */
    266	wksp_size = zstd_dstream_workspace_bound(header.windowSize);
    267	wksp = large_malloc(wksp_size);
    268	dstream = zstd_init_dstream(header.windowSize, wksp, wksp_size);
    269	if (dstream == NULL) {
    270		error("Out of memory while allocating ZSTD_DStream");
    271		err = -1;
    272		goto out;
    273	}
    274
    275	/*
    276	 * Decompression loop:
    277	 * Read more data if necessary (error if no more data can be read).
    278	 * Call the decompression function, which returns 0 when finished.
    279	 * Flush any data produced if using flush().
    280	 */
    281	if (in_pos != NULL)
    282		*in_pos = 0;
    283	do {
    284		/*
    285		 * If we need to reload data, either we have fill() and can
    286		 * try to get more data, or we don't and the input is truncated.
    287		 */
    288		if (in.pos == in.size) {
    289			if (in_pos != NULL)
    290				*in_pos += in.pos;
    291			in_len = fill ? fill(in_buf, ZSTD_IOBUF_SIZE) : -1;
    292			if (in_len < 0) {
    293				error("ZSTD-compressed data is truncated");
    294				err = -1;
    295				goto out;
    296			}
    297			in.pos = 0;
    298			in.size = in_len;
    299		}
    300		/* Returns zero when the frame is complete. */
    301		ret = zstd_decompress_stream(dstream, &out, &in);
    302		err = handle_zstd_error(ret, error);
    303		if (err)
    304			goto out;
    305		/* Flush all of the data produced if using flush(). */
    306		if (flush != NULL && out.pos > 0) {
    307			if (out.pos != flush(out.dst, out.pos)) {
    308				error("Failed to flush()");
    309				err = -1;
    310				goto out;
    311			}
    312			out.pos = 0;
    313		}
    314	} while (ret != 0);
    315
    316	if (in_pos != NULL)
    317		*in_pos += in.pos;
    318
    319	err = 0;
    320out:
    321	if (in_allocated != NULL)
    322		large_free(in_allocated);
    323	if (out_allocated != NULL)
    324		large_free(out_allocated);
    325	if (wksp != NULL)
    326		large_free(wksp);
    327	return err;
    328}
    329
    330#ifndef UNZSTD_PREBOOT
    331STATIC int INIT unzstd(unsigned char *buf, long len,
    332		       long (*fill)(void*, unsigned long),
    333		       long (*flush)(void*, unsigned long),
    334		       unsigned char *out_buf,
    335		       long *pos,
    336		       void (*error)(char *x))
    337{
    338	return __unzstd(buf, len, fill, flush, out_buf, 0, pos, error);
    339}
    340#else
    341STATIC int INIT __decompress(unsigned char *buf, long len,
    342			     long (*fill)(void*, unsigned long),
    343			     long (*flush)(void*, unsigned long),
    344			     unsigned char *out_buf, long out_len,
    345			     long *pos,
    346			     void (*error)(char *x))
    347{
    348	return __unzstd(buf, len, fill, flush, out_buf, out_len, pos, error);
    349}
    350#endif