/*
 * Copyright (c) 2016-2020, Yann Collet, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under both the BSD-style license (found in the
 * LICENSE file in the root directory of this source tree) and the GPLv2 (found
 * in the COPYING file in the root directory of this source tree).
 * You may select, at your option, one of the above-listed licenses.
 */
1051786Sdcs
 /*-*************************************
 *  Dependencies
 ***************************************/
1451786Sdcs#include "zstd_compress_sequences.h"
1551786Sdcs
/**
 * Lookup table: kInverseProbabilityLog256[x] is the information content of a
 * symbol whose probability is x/256, expressed in 1/256ths of a bit.
 *   - x == 0        : 0 (placeholder, the logarithm is undefined there)
 *   - x in [1, 256) : floor(-log2(x / 256) * 256)
 * Rows below hold 16 entries each; the leading marker is the index of the
 * first entry in the row.
 */
static const unsigned kInverseProbabilityLog256[256] = {
    /*   0 */ 0, 2048, 1792, 1642, 1536, 1453, 1386, 1329, 1280, 1236, 1197, 1162, 1130, 1100, 1073, 1047,
    /*  16 */ 1024, 1001, 980, 960, 941, 923, 906, 889, 874, 859, 844, 830, 817, 804, 791, 779,
    /*  32 */ 768, 756, 745, 734, 724, 714, 704, 694, 685, 676, 667, 658, 650, 642, 633, 626,
    /*  48 */ 618, 610, 603, 595, 588, 581, 574, 567, 561, 554, 548, 542, 535, 529, 523, 517,
    /*  64 */ 512, 506, 500, 495, 489, 484, 478, 473, 468, 463, 458, 453, 448, 443, 438, 434,
    /*  80 */ 429, 424, 420, 415, 411, 407, 402, 398, 394, 390, 386, 382, 377, 373, 370, 366,
    /*  96 */ 362, 358, 354, 350, 347, 343, 339, 336, 332, 329, 325, 322, 318, 315, 311, 308,
    /* 112 */ 305, 302, 298, 295, 292, 289, 286, 282, 279, 276, 273, 270, 267, 264, 261, 258,
    /* 128 */ 256, 253, 250, 247, 244, 241, 239, 236, 233, 230, 228, 225, 222, 220, 217, 215,
    /* 144 */ 212, 209, 207, 204, 202, 199, 197, 194, 192, 190, 187, 185, 182, 180, 178, 175,
    /* 160 */ 173, 171, 168, 166, 164, 162, 159, 157, 155, 153, 151, 149, 146, 144, 142, 140,
    /* 176 */ 138, 136, 134, 132, 130, 128, 126, 123, 121, 119, 117, 115, 114, 112, 110, 108,
    /* 192 */ 106, 104, 102, 100, 98, 96, 94, 93, 91, 89, 87, 85, 83, 82, 80, 78,
    /* 208 */ 76, 74, 73, 71, 69, 67, 66, 64, 62, 61, 59, 57, 55, 54, 52, 50,
    /* 224 */ 49, 47, 46, 44, 42, 41, 39, 37, 36, 34, 33, 31, 30, 28, 26, 25,
    /* 240 */ 23, 22, 20, 19, 17, 16, 14, 13, 11, 10, 8, 7, 5, 4, 2, 1
};
4551786Sdcs
4651786Sdcsstatic unsigned ZSTD_getFSEMaxSymbolValue(FSE_CTable const* ctable) {
4751786Sdcs  void const* ptr = ctable;
4851786Sdcs  U16 const* u16ptr = (U16 const*)ptr;
4951786Sdcs  U32 const maxSymbolValue = MEM_read16(u16ptr + 1);
50  return maxSymbolValue;
51}
52
53/**
54 * Returns the cost in bytes of encoding the normalized count header.
55 * Returns an error if any of the helper functions return an error.
56 */
57static size_t ZSTD_NCountCost(unsigned const* count, unsigned const max,
58                              size_t const nbSeq, unsigned const FSELog)
59{
60    BYTE wksp[FSE_NCOUNTBOUND];
61    S16 norm[MaxSeq + 1];
62    const U32 tableLog = FSE_optimalTableLog(FSELog, nbSeq, max);
63    FORWARD_IF_ERROR(FSE_normalizeCount(norm, tableLog, count, nbSeq, max), "");
64    return FSE_writeNCount(wksp, sizeof(wksp), norm, max, tableLog);
65}
66
67/**
68 * Returns the cost in bits of encoding the distribution described by count
69 * using the entropy bound.
70 */
71static size_t ZSTD_entropyCost(unsigned const* count, unsigned const max, size_t const total)
72{
73    unsigned cost = 0;
74    unsigned s;
75    for (s = 0; s <= max; ++s) {
76        unsigned norm = (unsigned)((256 * count[s]) / total);
77        if (count[s] != 0 && norm == 0)
78            norm = 1;
79        assert(count[s] < total);
80        cost += count[s] * kInverseProbabilityLog256[norm];
81    }
82    return cost >> 8;
83}
84
85/**
86 * Returns the cost in bits of encoding the distribution in count using ctable.
87 * Returns an error if ctable cannot represent all the symbols in count.
88 */
89size_t ZSTD_fseBitCost(
90    FSE_CTable const* ctable,
91    unsigned const* count,
92    unsigned const max)
93{
94    unsigned const kAccuracyLog = 8;
95    size_t cost = 0;
96    unsigned s;
97    FSE_CState_t cstate;
98    FSE_initCState(&cstate, ctable);
99    if (ZSTD_getFSEMaxSymbolValue(ctable) < max) {
100        DEBUGLOG(5, "Repeat FSE_CTable has maxSymbolValue %u < %u",
101                    ZSTD_getFSEMaxSymbolValue(ctable), max);
102        return ERROR(GENERIC);
103    }
104    for (s = 0; s <= max; ++s) {
105        unsigned const tableLog = cstate.stateLog;
106        unsigned const badCost = (tableLog + 1) << kAccuracyLog;
107        unsigned const bitCost = FSE_bitCost(cstate.symbolTT, tableLog, s, kAccuracyLog);
108        if (count[s] == 0)
109            continue;
110        if (bitCost >= badCost) {
111            DEBUGLOG(5, "Repeat FSE_CTable has Prob[%u] == 0", s);
112            return ERROR(GENERIC);
113        }
114        cost += (size_t)count[s] * bitCost;
115    }
116    return cost >> kAccuracyLog;
117}
118
119/**
120 * Returns the cost in bits of encoding the distribution in count using the
121 * table described by norm. The max symbol support by norm is assumed >= max.
122 * norm must be valid for every symbol with non-zero probability in count.
123 */
124size_t ZSTD_crossEntropyCost(short const* norm, unsigned accuracyLog,
125                             unsigned const* count, unsigned const max)
126{
127    unsigned const shift = 8 - accuracyLog;
128    size_t cost = 0;
129    unsigned s;
130    assert(accuracyLog <= 8);
131    for (s = 0; s <= max; ++s) {
132        unsigned const normAcc = (norm[s] != -1) ? (unsigned)norm[s] : 1;
133        unsigned const norm256 = normAcc << shift;
134        assert(norm256 > 0);
135        assert(norm256 < 256);
136        cost += count[s] * kInverseProbabilityLog256[norm256];
137    }
138    return cost >> 8;
139}
140
141symbolEncodingType_e
142ZSTD_selectEncodingType(
143        FSE_repeat* repeatMode, unsigned const* count, unsigned const max,
144        size_t const mostFrequent, size_t nbSeq, unsigned const FSELog,
145        FSE_CTable const* prevCTable,
146        short const* defaultNorm, U32 defaultNormLog,
147        ZSTD_defaultPolicy_e const isDefaultAllowed,
148        ZSTD_strategy const strategy)
149{
150    ZSTD_STATIC_ASSERT(ZSTD_defaultDisallowed == 0 && ZSTD_defaultAllowed != 0);
151    if (mostFrequent == nbSeq) {
152        *repeatMode = FSE_repeat_none;
153        if (isDefaultAllowed && nbSeq <= 2) {
154            /* Prefer set_basic over set_rle when there are 2 or less symbols,
155             * since RLE uses 1 byte, but set_basic uses 5-6 bits per symbol.
156             * If basic encoding isn't possible, always choose RLE.
157             */
158            DEBUGLOG(5, "Selected set_basic");
159            return set_basic;
160        }
161        DEBUGLOG(5, "Selected set_rle");
162        return set_rle;
163    }
164    if (strategy < ZSTD_lazy) {
165        if (isDefaultAllowed) {
166            size_t const staticFse_nbSeq_max = 1000;
167            size_t const mult = 10 - strategy;
168            size_t const baseLog = 3;
169            size_t const dynamicFse_nbSeq_min = (((size_t)1 << defaultNormLog) * mult) >> baseLog;  /* 28-36 for offset, 56-72 for lengths */
170            assert(defaultNormLog >= 5 && defaultNormLog <= 6);  /* xx_DEFAULTNORMLOG */
171            assert(mult <= 9 && mult >= 7);
172            if ( (*repeatMode == FSE_repeat_valid)
173              && (nbSeq < staticFse_nbSeq_max) ) {
174                DEBUGLOG(5, "Selected set_repeat");
175                return set_repeat;
176            }
177            if ( (nbSeq < dynamicFse_nbSeq_min)
178              || (mostFrequent < (nbSeq >> (defaultNormLog-1))) ) {
179                DEBUGLOG(5, "Selected set_basic");
180                /* The format allows default tables to be repeated, but it isn't useful.
181                 * When using simple heuristics to select encoding type, we don't want
182                 * to confuse these tables with dictionaries. When running more careful
183                 * analysis, we don't need to waste time checking both repeating tables
184                 * and default tables.
185                 */
186                *repeatMode = FSE_repeat_none;
187                return set_basic;
188            }
189        }
190    } else {
191        size_t const basicCost = isDefaultAllowed ? ZSTD_crossEntropyCost(defaultNorm, defaultNormLog, count, max) : ERROR(GENERIC);
192        size_t const repeatCost = *repeatMode != FSE_repeat_none ? ZSTD_fseBitCost(prevCTable, count, max) : ERROR(GENERIC);
193        size_t const NCountCost = ZSTD_NCountCost(count, max, nbSeq, FSELog);
194        size_t const compressedCost = (NCountCost << 3) + ZSTD_entropyCost(count, max, nbSeq);
195
196        if (isDefaultAllowed) {
197            assert(!ZSTD_isError(basicCost));
198            assert(!(*repeatMode == FSE_repeat_valid && ZSTD_isError(repeatCost)));
199        }
200        assert(!ZSTD_isError(NCountCost));
201        assert(compressedCost < ERROR(maxCode));
202        DEBUGLOG(5, "Estimated bit costs: basic=%u\trepeat=%u\tcompressed=%u",
203                    (unsigned)basicCost, (unsigned)repeatCost, (unsigned)compressedCost);
204        if (basicCost <= repeatCost && basicCost <= compressedCost) {
205            DEBUGLOG(5, "Selected set_basic");
206            assert(isDefaultAllowed);
207            *repeatMode = FSE_repeat_none;
208            return set_basic;
209        }
210        if (repeatCost <= compressedCost) {
211            DEBUGLOG(5, "Selected set_repeat");
212            assert(!ZSTD_isError(repeatCost));
213            return set_repeat;
214        }
215        assert(compressedCost < basicCost && compressedCost < repeatCost);
216    }
217    DEBUGLOG(5, "Selected set_compressed");
218    *repeatMode = FSE_repeat_check;
219    return set_compressed;
220}
221
222size_t
223ZSTD_buildCTable(void* dst, size_t dstCapacity,
224                FSE_CTable* nextCTable, U32 FSELog, symbolEncodingType_e type,
225                unsigned* count, U32 max,
226                const BYTE* codeTable, size_t nbSeq,
227                const S16* defaultNorm, U32 defaultNormLog, U32 defaultMax,
228                const FSE_CTable* prevCTable, size_t prevCTableSize,
229                void* entropyWorkspace, size_t entropyWorkspaceSize)
230{
231    BYTE* op = (BYTE*)dst;
232    const BYTE* const oend = op + dstCapacity;
233    DEBUGLOG(6, "ZSTD_buildCTable (dstCapacity=%u)", (unsigned)dstCapacity);
234
235    switch (type) {
236    case set_rle:
237        FORWARD_IF_ERROR(FSE_buildCTable_rle(nextCTable, (BYTE)max), "");
238        RETURN_ERROR_IF(dstCapacity==0, dstSize_tooSmall, "not enough space");
239        *op = codeTable[0];
240        return 1;
241    case set_repeat:
242        memcpy(nextCTable, prevCTable, prevCTableSize);
243        return 0;
244    case set_basic:
245        FORWARD_IF_ERROR(FSE_buildCTable_wksp(nextCTable, defaultNorm, defaultMax, defaultNormLog, entropyWorkspace, entropyWorkspaceSize), "");  /* note : could be pre-calculated */
246        return 0;
247    case set_compressed: {
248        S16 norm[MaxSeq + 1];
249        size_t nbSeq_1 = nbSeq;
250        const U32 tableLog = FSE_optimalTableLog(FSELog, nbSeq, max);
251        if (count[codeTable[nbSeq-1]] > 1) {
252            count[codeTable[nbSeq-1]]--;
253            nbSeq_1--;
254        }
255        assert(nbSeq_1 > 1);
256        FORWARD_IF_ERROR(FSE_normalizeCount(norm, tableLog, count, nbSeq_1, max), "");
257        {   size_t const NCountSize = FSE_writeNCount(op, oend - op, norm, max, tableLog);   /* overflow protected */
258            FORWARD_IF_ERROR(NCountSize, "FSE_writeNCount failed");
259            FORWARD_IF_ERROR(FSE_buildCTable_wksp(nextCTable, norm, max, tableLog, entropyWorkspace, entropyWorkspaceSize), "");
260            return NCountSize;
261        }
262    }
263    default: assert(0); RETURN_ERROR(GENERIC, "impossible to reach");
264    }
265}
266
267FORCE_INLINE_TEMPLATE size_t
268ZSTD_encodeSequences_body(
269            void* dst, size_t dstCapacity,
270            FSE_CTable const* CTable_MatchLength, BYTE const* mlCodeTable,
271            FSE_CTable const* CTable_OffsetBits, BYTE const* ofCodeTable,
272            FSE_CTable const* CTable_LitLength, BYTE const* llCodeTable,
273            seqDef const* sequences, size_t nbSeq, int longOffsets)
274{
275    BIT_CStream_t blockStream;
276    FSE_CState_t  stateMatchLength;
277    FSE_CState_t  stateOffsetBits;
278    FSE_CState_t  stateLitLength;
279
280    RETURN_ERROR_IF(
281        ERR_isError(BIT_initCStream(&blockStream, dst, dstCapacity)),
282        dstSize_tooSmall, "not enough space remaining");
283    DEBUGLOG(6, "available space for bitstream : %i  (dstCapacity=%u)",
284                (int)(blockStream.endPtr - blockStream.startPtr),
285                (unsigned)dstCapacity);
286
287    /* first symbols */
288    FSE_initCState2(&stateMatchLength, CTable_MatchLength, mlCodeTable[nbSeq-1]);
289    FSE_initCState2(&stateOffsetBits,  CTable_OffsetBits,  ofCodeTable[nbSeq-1]);
290    FSE_initCState2(&stateLitLength,   CTable_LitLength,   llCodeTable[nbSeq-1]);
291    BIT_addBits(&blockStream, sequences[nbSeq-1].litLength, LL_bits[llCodeTable[nbSeq-1]]);
292    if (MEM_32bits()) BIT_flushBits(&blockStream);
293    BIT_addBits(&blockStream, sequences[nbSeq-1].matchLength, ML_bits[mlCodeTable[nbSeq-1]]);
294    if (MEM_32bits()) BIT_flushBits(&blockStream);
295    if (longOffsets) {
296        U32 const ofBits = ofCodeTable[nbSeq-1];
297        unsigned const extraBits = ofBits - MIN(ofBits, STREAM_ACCUMULATOR_MIN-1);
298        if (extraBits) {
299            BIT_addBits(&blockStream, sequences[nbSeq-1].offset, extraBits);
300            BIT_flushBits(&blockStream);
301        }
302        BIT_addBits(&blockStream, sequences[nbSeq-1].offset >> extraBits,
303                    ofBits - extraBits);
304    } else {
305        BIT_addBits(&blockStream, sequences[nbSeq-1].offset, ofCodeTable[nbSeq-1]);
306    }
307    BIT_flushBits(&blockStream);
308
309    {   size_t n;
310        for (n=nbSeq-2 ; n<nbSeq ; n--) {      /* intentional underflow */
311            BYTE const llCode = llCodeTable[n];
312            BYTE const ofCode = ofCodeTable[n];
313            BYTE const mlCode = mlCodeTable[n];
314            U32  const llBits = LL_bits[llCode];
315            U32  const ofBits = ofCode;
316            U32  const mlBits = ML_bits[mlCode];
317            DEBUGLOG(6, "encoding: litlen:%2u - matchlen:%2u - offCode:%7u",
318                        (unsigned)sequences[n].litLength,
319                        (unsigned)sequences[n].matchLength + MINMATCH,
320                        (unsigned)sequences[n].offset);
321                                                                            /* 32b*/  /* 64b*/
322                                                                            /* (7)*/  /* (7)*/
323            FSE_encodeSymbol(&blockStream, &stateOffsetBits, ofCode);       /* 15 */  /* 15 */
324            FSE_encodeSymbol(&blockStream, &stateMatchLength, mlCode);      /* 24 */  /* 24 */
325            if (MEM_32bits()) BIT_flushBits(&blockStream);                  /* (7)*/
326            FSE_encodeSymbol(&blockStream, &stateLitLength, llCode);        /* 16 */  /* 33 */
327            if (MEM_32bits() || (ofBits+mlBits+llBits >= 64-7-(LLFSELog+MLFSELog+OffFSELog)))
328                BIT_flushBits(&blockStream);                                /* (7)*/
329            BIT_addBits(&blockStream, sequences[n].litLength, llBits);
330            if (MEM_32bits() && ((llBits+mlBits)>24)) BIT_flushBits(&blockStream);
331            BIT_addBits(&blockStream, sequences[n].matchLength, mlBits);
332            if (MEM_32bits() || (ofBits+mlBits+llBits > 56)) BIT_flushBits(&blockStream);
333            if (longOffsets) {
334                unsigned const extraBits = ofBits - MIN(ofBits, STREAM_ACCUMULATOR_MIN-1);
335                if (extraBits) {
336                    BIT_addBits(&blockStream, sequences[n].offset, extraBits);
337                    BIT_flushBits(&blockStream);                            /* (7)*/
338                }
339                BIT_addBits(&blockStream, sequences[n].offset >> extraBits,
340                            ofBits - extraBits);                            /* 31 */
341            } else {
342                BIT_addBits(&blockStream, sequences[n].offset, ofBits);     /* 31 */
343            }
344            BIT_flushBits(&blockStream);                                    /* (7)*/
345            DEBUGLOG(7, "remaining space : %i", (int)(blockStream.endPtr - blockStream.ptr));
346    }   }
347
348    DEBUGLOG(6, "ZSTD_encodeSequences: flushing ML state with %u bits", stateMatchLength.stateLog);
349    FSE_flushCState(&blockStream, &stateMatchLength);
350    DEBUGLOG(6, "ZSTD_encodeSequences: flushing Off state with %u bits", stateOffsetBits.stateLog);
351    FSE_flushCState(&blockStream, &stateOffsetBits);
352    DEBUGLOG(6, "ZSTD_encodeSequences: flushing LL state with %u bits", stateLitLength.stateLog);
353    FSE_flushCState(&blockStream, &stateLitLength);
354
355    {   size_t const streamSize = BIT_closeCStream(&blockStream);
356        RETURN_ERROR_IF(streamSize==0, dstSize_tooSmall, "not enough space");
357        return streamSize;
358    }
359}
360
361static size_t
362ZSTD_encodeSequences_default(
363            void* dst, size_t dstCapacity,
364            FSE_CTable const* CTable_MatchLength, BYTE const* mlCodeTable,
365            FSE_CTable const* CTable_OffsetBits, BYTE const* ofCodeTable,
366            FSE_CTable const* CTable_LitLength, BYTE const* llCodeTable,
367            seqDef const* sequences, size_t nbSeq, int longOffsets)
368{
369    return ZSTD_encodeSequences_body(dst, dstCapacity,
370                                    CTable_MatchLength, mlCodeTable,
371                                    CTable_OffsetBits, ofCodeTable,
372                                    CTable_LitLength, llCodeTable,
373                                    sequences, nbSeq, longOffsets);
374}
375
376
#if DYNAMIC_BMI2

/**
 * Instantiation of ZSTD_encodeSequences_body() compiled with the "bmi2"
 * target attribute. Only built when DYNAMIC_BMI2 is enabled.
 * All arguments are forwarded unchanged.
 */
static TARGET_ATTRIBUTE("bmi2") size_t
ZSTD_encodeSequences_bmi2(
            void* dst, size_t dstCapacity,
            FSE_CTable const* CTable_MatchLength, BYTE const* mlCodeTable,
            FSE_CTable const* CTable_OffsetBits, BYTE const* ofCodeTable,
            FSE_CTable const* CTable_LitLength, BYTE const* llCodeTable,
            seqDef const* sequences, size_t nbSeq, int longOffsets)
{
    size_t const encodedSize = ZSTD_encodeSequences_body(
            dst, dstCapacity,
            CTable_MatchLength, mlCodeTable,
            CTable_OffsetBits, ofCodeTable,
            CTable_LitLength, llCodeTable,
            sequences, nbSeq, longOffsets);
    return encodedSize;
}

#endif
395
396size_t ZSTD_encodeSequences(
397            void* dst, size_t dstCapacity,
398            FSE_CTable const* CTable_MatchLength, BYTE const* mlCodeTable,
399            FSE_CTable const* CTable_OffsetBits, BYTE const* ofCodeTable,
400            FSE_CTable const* CTable_LitLength, BYTE const* llCodeTable,
401            seqDef const* sequences, size_t nbSeq, int longOffsets, int bmi2)
402{
403    DEBUGLOG(5, "ZSTD_encodeSequences: dstCapacity = %u", (unsigned)dstCapacity);
404#if DYNAMIC_BMI2
405    if (bmi2) {
406        return ZSTD_encodeSequences_bmi2(dst, dstCapacity,
407                                         CTable_MatchLength, mlCodeTable,
408                                         CTable_OffsetBits, ofCodeTable,
409                                         CTable_LitLength, llCodeTable,
410                                         sequences, nbSeq, longOffsets);
411    }
412#endif
413    (void)bmi2;
414    return ZSTD_encodeSequences_default(dst, dstCapacity,
415                                        CTable_MatchLength, mlCodeTable,
416                                        CTable_OffsetBits, ofCodeTable,
417                                        CTable_LitLength, llCodeTable,
418                                        sequences, nbSeq, longOffsets);
419}
420