entropy_common.c
/* ******************************************************************
 * Common functions of New Generation Entropy library
 * Copyright (c) Yann Collet, Facebook, Inc.
 *
 *  You can contact the author at :
 *  - FSE+HUF source repository : https://github.com/Cyan4973/FiniteStateEntropy
 *  - Public forum : https://groups.google.com/forum/#!forum/lz4c
 *
 * This source code is licensed under both the BSD-style license (found in the
 * LICENSE file in the root directory of this source tree) and the GPLv2 (found
 * in the COPYING file in the root directory of this source tree).
 * You may select, at your option, one of the above-listed licenses.
****************************************************************** */

/* *************************************
*  Dependencies
***************************************/
#include "mem.h"
#include "error_private.h"       /* ERR_*, ERROR */
#define FSE_STATIC_LINKING_ONLY  /* FSE_MIN_TABLELOG */
#include "fse.h"
#define HUF_STATIC_LINKING_ONLY  /* HUF_TABLELOG_ABSOLUTEMAX */
#include "huf.h"


/*===   Version   ===*/
unsigned FSE_versionNumber(void) { return FSE_VERSION_NUMBER; }


/*===   Error Management   ===*/
unsigned FSE_isError(size_t code) { return ERR_isError(code); }
const char* FSE_getErrorName(size_t code) { return ERR_getErrorName(code); }

unsigned HUF_isError(size_t code) { return ERR_isError(code); }
const char* HUF_getErrorName(size_t code) { return ERR_getErrorName(code); }


/*-**************************************************************
*  FSE NCount encoding-decoding
****************************************************************/
static U32 FSE_ctz(U32 val)
{
    assert(val != 0);
    {
#   if (__GNUC__ >= 3)   /* GCC Intrinsic */
        return __builtin_ctz(val);
#   else   /* Software version */
        U32 count = 0;
        while ((val & 1) == 0) {
            val >>= 1;
            ++count;
        }
        return count;
#   endif
    }
}
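
/* Illustrative note : FSE_ctz() lets the decoder below consume consecutive 0b11
 * "repeat" codes in a single step. For example, if the low bits of bitStream are
 * ...00 10 11 11 (2-bit fields read from the least significant bit : 11, 11, 10),
 * then ~bitStream ends in ...11 01 00 00, FSE_ctz() returns 4, and 4 >> 1 == 2
 * repeat fields are counted at once. OR-ing with 0x80000000 keeps the argument
 * non-zero, so the trailing-zero count stays defined. */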

FORCE_INLINE_TEMPLATE
size_t FSE_readNCount_body(short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr,
                           const void* headerBuffer, size_t hbSize)
{
    const BYTE* const istart = (const BYTE*) headerBuffer;
    const BYTE* const iend = istart + hbSize;
    const BYTE* ip = istart;
    int nbBits;
    int remaining;
    int threshold;
    U32 bitStream;
    int bitCount;
    unsigned charnum = 0;
    unsigned const maxSV1 = *maxSVPtr + 1;
    int previous0 = 0;

    if (hbSize < 8) {
        /* This function only works when hbSize >= 8 */
        char buffer[8] = {0};
        ZSTD_memcpy(buffer, headerBuffer, hbSize);
        {   size_t const countSize = FSE_readNCount(normalizedCounter, maxSVPtr, tableLogPtr,
                                                    buffer, sizeof(buffer));
            if (FSE_isError(countSize)) return countSize;
            if (countSize > hbSize) return ERROR(corruption_detected);
            return countSize;
    }   }
    assert(hbSize >= 8);

    /* init */
    ZSTD_memset(normalizedCounter, 0, (*maxSVPtr+1) * sizeof(normalizedCounter[0]));   /* all symbols not present in NCount have a frequency of 0 */
    bitStream = MEM_readLE32(ip);
    nbBits = (bitStream & 0xF) + FSE_MIN_TABLELOG;   /* extract tableLog */
    if (nbBits > FSE_TABLELOG_ABSOLUTE_MAX) return ERROR(tableLog_tooLarge);
    bitStream >>= 4;
    bitCount = 4;
    *tableLogPtr = nbBits;
    remaining = (1<<nbBits)+1;
    threshold = 1<<nbBits;
    nbBits++;

    for (;;) {
        if (previous0) {
            /* Count the number of repeats. Each time the
             * 2-bit repeat code is 0b11 there is another
             * repeat.
             * Avoid UB by setting the high bit to 1.
             */
            int repeats = FSE_ctz(~bitStream | 0x80000000) >> 1;
            while (repeats >= 12) {
                charnum += 3 * 12;
                if (LIKELY(ip <= iend-7)) {
                    ip += 3;
                } else {
                    bitCount -= (int)(8 * (iend - 7 - ip));
                    bitCount &= 31;
                    ip = iend - 4;
                }
                bitStream = MEM_readLE32(ip) >> bitCount;
                repeats = FSE_ctz(~bitStream | 0x80000000) >> 1;
            }
            charnum += 3 * repeats;
            bitStream >>= 2 * repeats;
            bitCount += 2 * repeats;

            /* Add the final repeat which isn't 0b11. */
            assert((bitStream & 3) < 3);
            charnum += bitStream & 3;
            bitCount += 2;

            /* This is an error, but break and return an error
             * at the end, because returning out of a loop makes
             * it harder for the compiler to optimize.
             */
            if (charnum >= maxSV1) break;

            /* We don't need to set the normalized count to 0
             * because we already memset the whole buffer to 0.
             */

            if (LIKELY(ip <= iend-7) || (ip + (bitCount>>3) <= iend-4)) {
                assert((bitCount >> 3) <= 3); /* For first condition to work */
                ip += bitCount>>3;
                bitCount &= 7;
            } else {
                bitCount -= (int)(8 * (iend - 4 - ip));
                bitCount &= 31;
                ip = iend - 4;
            }
            bitStream = MEM_readLE32(ip) >> bitCount;
        }
        {
            int const max = (2*threshold-1) - remaining;
            int count;

            if ((bitStream & (threshold-1)) < (U32)max) {
                count = bitStream & (threshold-1);
                bitCount += nbBits-1;
            } else {
                count = bitStream & (2*threshold-1);
                if (count >= threshold) count -= max;
                bitCount += nbBits;
            }

            count--;   /* extra accuracy */
            /* When it matters (small blocks), this is a
             * predictable branch, because we don't use -1.
             */
            if (count >= 0) {
                remaining -= count;
            } else {
                assert(count == -1);
                remaining += count;
            }
            normalizedCounter[charnum++] = (short)count;
            previous0 = !count;

            assert(threshold > 1);
            if (remaining < threshold) {
                /* This branch can be folded into the
                 * threshold update condition because we
                 * know that threshold > 1.
                 */
                if (remaining <= 1) break;
                nbBits = BIT_highbit32(remaining) + 1;
                threshold = 1 << (nbBits - 1);
            }
            if (charnum >= maxSV1) break;

            if (LIKELY(ip <= iend-7) || (ip + (bitCount>>3) <= iend-4)) {
                ip += bitCount>>3;
                bitCount &= 7;
            } else {
                bitCount -= (int)(8 * (iend - 4 - ip));
                bitCount &= 31;
                ip = iend - 4;
            }
            bitStream = MEM_readLE32(ip) >> bitCount;
    }   }
    if (remaining != 1) return ERROR(corruption_detected);
    /* Only possible when there are too many zeros. */
    if (charnum > maxSV1) return ERROR(maxSymbolValue_tooSmall);
    if (bitCount > 32) return ERROR(corruption_detected);
    *maxSVPtr = charnum-1;

    ip += (bitCount+7)>>3;
    return ip-istart;
}
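
/* Illustrative note : worked example of the variable-length count coding read by
 * FSE_readNCount_body() above. With tableLog == 5 the decoder starts with
 * remaining = 33, threshold = 32 and nbBits = 6, so max = (2*32-1) - 33 = 30 :
 * values 0..29 fit in nbBits-1 = 5 bits, larger values need all 6 bits. If the
 * next 5 bits hold 7 (< 30), then count = 7, one is subtracted ("extra accuracy"),
 * the symbol gets a normalized count of 6, and remaining drops to 27. Since
 * 27 < 32, nbBits becomes BIT_highbit32(27) + 1 = 5 and threshold becomes 16 for
 * the next symbol. Decoding ends when remaining reaches 1, i.e. when the
 * normalized counts (with -1 counted as 1) sum to 1<<tableLog. */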

/* Avoids the FORCE_INLINE of the _body() function. */
static size_t FSE_readNCount_body_default(
        short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr,
        const void* headerBuffer, size_t hbSize)
{
    return FSE_readNCount_body(normalizedCounter, maxSVPtr, tableLogPtr, headerBuffer, hbSize);
}

#if DYNAMIC_BMI2
TARGET_ATTRIBUTE("bmi2") static size_t FSE_readNCount_body_bmi2(
        short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr,
        const void* headerBuffer, size_t hbSize)
{
    return FSE_readNCount_body(normalizedCounter, maxSVPtr, tableLogPtr, headerBuffer, hbSize);
}
#endif

size_t FSE_readNCount_bmi2(
        short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr,
        const void* headerBuffer, size_t hbSize, int bmi2)
{
#if DYNAMIC_BMI2
    if (bmi2) {
        return FSE_readNCount_body_bmi2(normalizedCounter, maxSVPtr, tableLogPtr, headerBuffer, hbSize);
    }
#endif
    (void)bmi2;
    return FSE_readNCount_body_default(normalizedCounter, maxSVPtr, tableLogPtr, headerBuffer, hbSize);
}

size_t FSE_readNCount(
        short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr,
        const void* headerBuffer, size_t hbSize)
{
    return FSE_readNCount_bmi2(normalizedCounter, maxSVPtr, tableLogPtr, headerBuffer, hbSize, /* bmi2 */ 0);
}
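
/* Illustrative usage sketch : `ncountSrc` and `ncountSrcSize` are hypothetical and
 * would point at a serialized NCount header.
 *
 *     short    ncount[256];            // FSE_MAX_SYMBOL_VALUE+1 entries
 *     unsigned maxSymbolValue = 255;   // in : largest symbol value accepted
 *     unsigned tableLog;               // out : accuracy log of the table
 *     size_t const hSize = FSE_readNCount(ncount, &maxSymbolValue, &tableLog,
 *                                         ncountSrc, ncountSrcSize);
 *     if (FSE_isError(hSize)) return hSize;
 *
 * On success, hSize input bytes were consumed, maxSymbolValue is updated to the
 * last symbol actually present, and ncount[0..maxSymbolValue] holds the normalized
 * counts (-1 denotes a "less than 1" probability). */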


/*! HUF_readStats() :
    Read compact Huffman tree, saved by HUF_writeCTable().
    `huffWeight` is destination buffer.
    `rankStats` is assumed to be a table of at least HUF_TABLELOG_MAX+1 U32.
    @return : size read from `src` , or an error Code .
    Note : Needed by HUF_readCTable() and HUF_readDTableX?() .
*/
size_t HUF_readStats(BYTE* huffWeight, size_t hwSize, U32* rankStats,
                     U32* nbSymbolsPtr, U32* tableLogPtr,
                     const void* src, size_t srcSize)
{
    U32 wksp[HUF_READ_STATS_WORKSPACE_SIZE_U32];
    return HUF_readStats_wksp(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, wksp, sizeof(wksp), /* bmi2 */ 0);
}

FORCE_INLINE_TEMPLATE size_t
HUF_readStats_body(BYTE* huffWeight, size_t hwSize, U32* rankStats,
                   U32* nbSymbolsPtr, U32* tableLogPtr,
                   const void* src, size_t srcSize,
                   void* workSpace, size_t wkspSize,
                   int bmi2)
{
    U32 weightTotal;
    const BYTE* ip = (const BYTE*) src;
    size_t iSize;
    size_t oSize;

    if (!srcSize) return ERROR(srcSize_wrong);
    iSize = ip[0];
    /* ZSTD_memset(huffWeight, 0, hwSize); */   /* is not necessary, even though some analyzer complain ... */

    if (iSize >= 128) {  /* special header */
        oSize = iSize - 127;
        iSize = ((oSize+1)/2);
        if (iSize+1 > srcSize) return ERROR(srcSize_wrong);
        if (oSize >= hwSize) return ERROR(corruption_detected);
        ip += 1;
        {   U32 n;
            for (n=0; n<oSize; n+=2) {
                huffWeight[n]   = ip[n/2] >> 4;
                huffWeight[n+1] = ip[n/2] & 15;
    }   }   }
    else  {   /* header compressed with FSE (normal case) */
        if (iSize+1 > srcSize) return ERROR(srcSize_wrong);
        /* max (hwSize-1) values decoded, as last one is implied */
        oSize = FSE_decompress_wksp_bmi2(huffWeight, hwSize-1, ip+1, iSize, 6, workSpace, wkspSize, bmi2);
        if (FSE_isError(oSize)) return oSize;
    }

    /* collect weight stats */
    ZSTD_memset(rankStats, 0, (HUF_TABLELOG_MAX + 1) * sizeof(U32));
    weightTotal = 0;
    {   U32 n; for (n=0; n<oSize; n++) {
            if (huffWeight[n] >= HUF_TABLELOG_MAX) return ERROR(corruption_detected);
            rankStats[huffWeight[n]]++;
            weightTotal += (1 << huffWeight[n]) >> 1;
    }   }
    if (weightTotal == 0) return ERROR(corruption_detected);

    /* get last non-null symbol weight (implied, total must be 2^n) */
    {   U32 const tableLog = BIT_highbit32(weightTotal) + 1;
        if (tableLog > HUF_TABLELOG_MAX) return ERROR(corruption_detected);
        *tableLogPtr = tableLog;
        /* determine last weight */
        {   U32 const total = 1 << tableLog;
            U32 const rest = total - weightTotal;
            U32 const verif = 1 << BIT_highbit32(rest);
            U32 const lastWeight = BIT_highbit32(rest) + 1;
            if (verif != rest) return ERROR(corruption_detected);    /* last value must be a clean power of 2 */
            huffWeight[oSize] = (BYTE)lastWeight;
            rankStats[lastWeight]++;
    }   }

    /* check tree construction validity */
    if ((rankStats[1] < 2) || (rankStats[1] & 1)) return ERROR(corruption_detected);   /* by construction : at least 2 elts of rank 1, must be even */

    /* results */
    *nbSymbolsPtr = (U32)(oSize+1);
    return iSize+1;
}

/* Avoids the FORCE_INLINE of the _body() function. */
static size_t HUF_readStats_body_default(BYTE* huffWeight, size_t hwSize, U32* rankStats,
                     U32* nbSymbolsPtr, U32* tableLogPtr,
                     const void* src, size_t srcSize,
                     void* workSpace, size_t wkspSize)
{
    return HUF_readStats_body(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, workSpace, wkspSize, 0);
}

#if DYNAMIC_BMI2
static TARGET_ATTRIBUTE("bmi2") size_t HUF_readStats_body_bmi2(BYTE* huffWeight, size_t hwSize, U32* rankStats,
                     U32* nbSymbolsPtr, U32* tableLogPtr,
                     const void* src, size_t srcSize,
                     void* workSpace, size_t wkspSize)
{
    return HUF_readStats_body(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, workSpace, wkspSize, 1);
}
#endif

size_t HUF_readStats_wksp(BYTE* huffWeight, size_t hwSize, U32* rankStats,
                     U32* nbSymbolsPtr, U32* tableLogPtr,
                     const void* src, size_t srcSize,
                     void* workSpace, size_t wkspSize,
                     int bmi2)
{
#if DYNAMIC_BMI2
    if (bmi2) {
        return HUF_readStats_body_bmi2(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, workSpace, wkspSize);
    }
#endif
    (void)bmi2;
    return HUF_readStats_body_default(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, workSpace, wkspSize);
}
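
/* Illustrative usage sketch : `src` and `srcSize` are hypothetical and would point
 * at a serialized Huffman table description.
 *
 *     BYTE huffWeight[HUF_SYMBOLVALUE_MAX + 1];
 *     U32  rankStats[HUF_TABLELOG_MAX + 1];
 *     U32  nbSymbols, tableLog;
 *     size_t const hSize = HUF_readStats(huffWeight, sizeof(huffWeight), rankStats,
 *                                        &nbSymbols, &tableLog, src, srcSize);
 *     if (HUF_isError(hSize)) return hSize;
 *
 * On success, hSize input bytes were consumed, huffWeight[0..nbSymbols-1] holds the
 * decoded weights (including the implied last one), rankStats[w] counts the symbols
 * of weight w, and tableLog receives the reconstructed table depth. */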