1/*
2 * Copyright (c) 2017-present, Facebook, Inc.
3 * All rights reserved.
4 *
5 * This source code is licensed under both the BSD-style license (found in the
6 * LICENSE file in the root directory of this source tree) and the GPLv2 (found
7 * in the COPYING file in the root directory of this source tree).
8 */
9
10/// Zstandard educational decoder implementation
11/// See https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md
12
13#include <stdint.h>
14#include <stdio.h>
15#include <stdlib.h>
16#include <string.h>
17#include "zstd_decompress.h"
18
/******* UTILITY MACROS AND TYPES *********************************************/
// The maximum decompressed size of a block is 128 KB, and a literals section
// can never be larger than the block that contains it, so 128 KB also bounds
// the size of any literals section.
#define MAX_LITERALS_SIZE ((size_t)128 * 1024)

// Classic max/min macros.  NOTE: each argument may be evaluated more than
// once, so never pass expressions with side effects (e.g. `MAX(i++, j)`).
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))

/// This decoder calls exit(1) when it encounters an error, however a production
/// library should propagate error codes
#define ERROR(s)                                                               \
    do {                                                                       \
        fprintf(stderr, "Error: %s\n", s);                                     \
        exit(1);                                                               \
    } while (0)
// Shorthands for the specific error categories reported by this decoder; each
// one prints its message and terminates the process via ERROR
#define INP_SIZE()                                                             \
    ERROR("Input buffer smaller than it should be or input is "                \
          "corrupted")
#define OUT_SIZE() ERROR("Output buffer too small for output")
#define CORRUPTION() ERROR("Corruption detected while decompressing")
#define BAD_ALLOC() ERROR("Memory allocation error")
#define IMPOSSIBLE() ERROR("An impossibility has occurred")

// Short aliases for the fixed-width integer types used throughout the decoder
typedef uint8_t u8;
typedef uint16_t u16;
typedef uint32_t u32;
typedef uint64_t u64;

typedef int8_t i8;
typedef int16_t i16;
typedef int32_t i32;
typedef int64_t i64;
/******* END UTILITY MACROS AND TYPES *****************************************/
52
/******* IMPLEMENTATION PRIMITIVE PROTOTYPES **********************************/
/// The implementations for these functions can be found at the bottom of this
/// file.  They implement low-level functionality needed for the higher level
/// decompression functions.

/*** IO STREAM OPERATIONS *************/

/// ostream_t/istream_t are used to wrap the pointers/length data passed into
/// ZSTD_decompress, so that all IO operations are safely bounds checked.
/// They are written/read forward, and reads are treated as little-endian.
/// They should be used opaquely to ensure safety.
typedef struct {
    // Current write position.  It moves forward as output is produced, so its
    // distance from the start of the destination buffer is the number of
    // bytes written so far.
    u8 *ptr;
    // Length of the writable region
    // NOTE(review): presumably the remaining capacity after `ptr` -- confirm
    // against IO_get_write_ptr
    size_t len;
} ostream_t;

typedef struct {
    // Current read position
    const u8 *ptr;
    // Number of readable bytes
    size_t len;

    // Input often reads a few bits at a time, so maintain an internal offset
    // (measured in bits)
    int bit_offset;
} istream_t;

/// The following two functions (IO_read_bits and IO_rewind_bits) are the only
/// ones that allow the istream to be non-byte aligned

/// Reads `num_bits` bits from a bitstream, and updates the internal offset
static inline u64 IO_read_bits(istream_t *const in, const int num_bits);
/// Backs-up the stream by `num_bits` bits so they can be read again
static inline void IO_rewind_bits(istream_t *const in, const int num_bits);
/// If the remaining bits in a byte will be unused, advance to the end of the
/// byte
static inline void IO_align_stream(istream_t *const in);

/// Write the given byte into the output stream
static inline void IO_write_byte(ostream_t *const out, u8 symb);

/// Returns the number of bytes left to be read in this stream.  The stream must
/// be byte aligned.
static inline size_t IO_istream_len(const istream_t *const in);

/// Advances the stream by `len` bytes, and returns a pointer to the chunk that
/// was skipped.  The stream must be byte aligned.
static inline const u8 *IO_get_read_ptr(istream_t *const in, size_t len);
/// Advances the stream by `len` bytes, and returns a pointer to the chunk that
/// was skipped so it can be written to.
static inline u8 *IO_get_write_ptr(ostream_t *const out, size_t len);

/// Advance the inner state by `len` bytes.  The stream must be byte aligned.
static inline void IO_advance_input(istream_t *const in, size_t len);

/// Returns an `ostream_t` constructed from the given pointer and length.
static inline ostream_t IO_make_ostream(u8 *out, size_t len);
/// Returns an `istream_t` constructed from the given pointer and length.
static inline istream_t IO_make_istream(const u8 *in, size_t len);

/// Returns an `istream_t` with the same base as `in`, and length `len`.
/// Then, advance `in` to account for the consumed bytes.
/// `in` must be byte aligned.
static inline istream_t IO_make_sub_istream(istream_t *const in, size_t len);
/*** END IO STREAM OPERATIONS *********/
115
/*** BITSTREAM OPERATIONS *************/
/// Read `num_bits` bits (up to 64) from `src + offset`, where `offset` is in
/// bits, and return them interpreted as a little-endian unsigned integer.
static inline u64 read_bits_LE(const u8 *src, const int num_bits,
                               const size_t offset);

/// Read bits from the end of a HUF or FSE bitstream (these are consumed
/// backwards).  `*offset` is a bit position: it is first moved back to
/// `*offset - bits`, and then `bits` bits are read from `src + *offset`.  If
/// the offset becomes negative, the extra bits at the bottom are filled in
/// with `0` bits instead of reading from before `src`.
static inline u64 STREAM_read_bits(const u8 *src, const int bits,
                                   i64 *const offset);
/*** END BITSTREAM OPERATIONS *********/
129
/*** BIT COUNTING OPERATIONS **********/
/// Returns the index of the highest set bit in `num`, or `-1` if `num == 0`.
/// Used wherever a value's bit length is needed.
static inline int highest_set_bit(const u64 num);
/*** END BIT COUNTING OPERATIONS ******/
134
/*** HUFFMAN PRIMITIVES ***************/
// Table-based decoding uses memory exponential in the maximum code length, so
// the depth must be limited
#define HUF_MAX_BITS (16)

// Limit the maximum number of symbols to 256 so we can store a symbol in a byte
#define HUF_MAX_SYMBS (256)

/// Structure containing all tables necessary for efficient Huffman decoding
typedef struct {
    // Lookup tables; these are the "malloc'ed parts" released by
    // HUF_free_dtable (NOTE(review): confirm both are heap-allocated)
    u8 *symbols;
    u8 *num_bits;
    // Maximum code length in bits used by this table
    int max_bits;
} HUF_dtable;

/// Decode a single symbol and read in enough bits to refresh the state
static inline u8 HUF_decode_symbol(const HUF_dtable *const dtable,
                                   u16 *const state, const u8 *const src,
                                   i64 *const offset);
/// Read in a full state's worth of bits to initialize it
static inline void HUF_init_state(const HUF_dtable *const dtable,
                                  u16 *const state, const u8 *const src,
                                  i64 *const offset);

/// Decompresses a single Huffman stream, returns the number of bytes decoded.
/// `in` must contain exactly the Huffman-coded stream with no trailing bytes,
/// because the bitstream is read starting from its end.
static size_t HUF_decompress_1stream(const HUF_dtable *const dtable,
                                     ostream_t *const out, istream_t *const in);
/// Same as previous but decodes 4 streams, formatted as in the Zstandard
/// specification.
/// `in` must contain exactly the Huffman-coded block with no trailing bytes.
static size_t HUF_decompress_4stream(const HUF_dtable *const dtable,
                                     ostream_t *const out, istream_t *const in);

/// Initialize a Huffman decoding table using the table of bit counts provided
static void HUF_init_dtable(HUF_dtable *const table, const u8 *const bits,
                            const int num_symbs);
/// Initialize a Huffman decoding table using the table of weights provided
/// Weights follow the definition provided in the Zstandard specification
static void HUF_init_dtable_usingweights(HUF_dtable *const table,
                                         const u8 *const weights,
                                         const int num_symbs);

/// Free the malloc'ed parts of a decoding table
static void HUF_free_dtable(HUF_dtable *const dtable);

/// Deep copy a decoding table, so that it can be used and free'd without
/// impacting the source table.
static void HUF_copy_dtable(HUF_dtable *const dst, const HUF_dtable *const src);
/*** END HUFFMAN PRIMITIVES ***********/
184
/*** FSE PRIMITIVES *******************/
/// For more description of FSE see
/// https://github.com/Cyan4973/FiniteStateEntropy/

// FSE table decoding uses memory exponential in the accuracy log, so limit the
// maximum accuracy
#define FSE_MAX_ACCURACY_LOG (15)
// Limit the maximum number of symbols so they can be stored in a single byte
#define FSE_MAX_SYMBS (256)

/// The tables needed to decode FSE encoded streams
typedef struct {
    // Per-state lookup tables; these are the "malloc'ed parts" released by
    // FSE_free_dtable (NOTE(review): confirm all three are heap-allocated)
    u8 *symbols;
    u8 *num_bits;
    u16 *new_state_base;
    // Accuracy log of the table (bounded by FSE_MAX_ACCURACY_LOG)
    int accuracy_log;
} FSE_dtable;

/// Return the symbol for the current state
static inline u8 FSE_peek_symbol(const FSE_dtable *const dtable,
                                 const u16 state);
/// Read the number of bits necessary to update state, update, and shift offset
/// back to reflect the bits read
static inline void FSE_update_state(const FSE_dtable *const dtable,
                                    u16 *const state, const u8 *const src,
                                    i64 *const offset);

/// Combine peek and update: decode a symbol and update the state
static inline u8 FSE_decode_symbol(const FSE_dtable *const dtable,
                                   u16 *const state, const u8 *const src,
                                   i64 *const offset);

/// Read bits from the stream to initialize the state and shift offset back
static inline void FSE_init_state(const FSE_dtable *const dtable,
                                  u16 *const state, const u8 *const src,
                                  i64 *const offset);

/// Decompress two interleaved bitstreams (e.g. compressed Huffman weights)
/// using an FSE decoding table.  `in` must contain exactly the encoded block
/// with no trailing bytes, because the bitstream is read from its end.
static size_t FSE_decompress_interleaved2(const FSE_dtable *const dtable,
                                          ostream_t *const out,
                                          istream_t *const in);

/// Initialize a decoding table using normalized frequencies.
static void FSE_init_dtable(FSE_dtable *const dtable,
                            const i16 *const norm_freqs, const int num_symbs,
                            const int accuracy_log);

/// Decode an FSE header as defined in the Zstandard format specification and
/// use the decoded frequencies to initialize a decoding table.
static void FSE_decode_header(FSE_dtable *const dtable, istream_t *const in,
                                const int max_accuracy_log);

/// Initialize an FSE table that will always return the same symbol and consume
/// 0 bits per symbol, to be used for RLE mode in sequence commands
static void FSE_init_dtable_rle(FSE_dtable *const dtable, const u8 symb);

/// Free the malloc'ed parts of a decoding table
static void FSE_free_dtable(FSE_dtable *const dtable);

/// Deep copy a decoding table, so that it can be used and free'd without
/// impacting the source table.
static void FSE_copy_dtable(FSE_dtable *const dst, const FSE_dtable *const src);
/*** END FSE PRIMITIVES ***************/

/******* END IMPLEMENTATION PRIMITIVE PROTOTYPES ******************************/
251
/******* ZSTD HELPER STRUCTS AND PROTOTYPES ***********************************/

/// A small structure that can be reused in various places that need to access
/// frame header information
typedef struct {
    // The size of the window that must be stored contiguously so that
    // back-references can be resolved
    size_t window_size;
    // The total output size of this compressed frame (0 when the header does
    // not declare it, or when the frame is declared empty)
    size_t frame_content_size;

    // The dictionary id if this frame uses one (0 when the header has no
    // Dictionary_ID field)
    u32 dictionary_id;

    // Whether or not the content of this frame has a checksum
    int content_checksum_flag;
    // Whether or not the output for this frame is in a single segment
    int single_segment_flag;
} frame_header_t;

/// The context needed to decode blocks in a frame
typedef struct {
    frame_header_t header;

    // The total amount of output produced so far in this frame, i.e. the data
    // available for backreferences.  Used to detect offsets that reach too far
    // back to be valid.
    size_t current_total_output;

    // Dictionary content available for backreferences (NULL/0 without one)
    const u8 *dict_content;
    size_t dict_content_len;

    // Entropy encoding tables so they can be repeated by future blocks instead
    // of retransmitting
    HUF_dtable literals_dtable;
    FSE_dtable ll_dtable;
    FSE_dtable ml_dtable;
    FSE_dtable of_dtable;

    // The last 3 offsets for the special "repeat offsets".
    u64 previous_offsets[3];
} frame_context_t;
293
/// The decoded contents of a dictionary so that it doesn't have to be repeated
/// for each frame that uses it
struct dictionary_s {
    // Entropy tables; deep-copied into a frame's context when the dictionary
    // has a non-zero dictionary_id
    HUF_dtable literals_dtable;
    FSE_dtable ll_dtable;
    FSE_dtable ml_dtable;
    FSE_dtable of_dtable;

    // Raw content for backreferences.  A NULL pointer marks an empty
    // dictionary that is ignored when decoding.
    u8 *content;
    size_t content_size;

    // Offset history to prepopulate the frame's history
    u64 previous_offsets[3];

    // Identifier compared against the frame header's Dictionary_ID.  0 marks a
    // raw-content dictionary whose tables and offsets are not applied.
    u32 dictionary_id;
};
312
/// A tuple containing the parts necessary to decode and execute a ZSTD sequence
/// command
typedef struct {
    // Number of literal bytes to copy before the match
    u32 literal_length;
    // Number of bytes to copy from the match position
    u32 match_length;
    // Offset code as decoded: either an actual offset value or an index into
    // the repeat-offset history (resolved by compute_offset)
    u32 offset;
} sequence_command_t;
320
/// The decoder works top-down, starting at the high level like Zstd frames, and
/// working down to lower, more technical levels such as blocks, literals, and
/// sequences.  The high-level functions roughly follow the outline of the
/// format specification:
/// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md

/// Before the implementation of each high-level function declared here, the
/// prototypes for its helper functions are defined and explained

/// Decode a single Zstd frame, or error if the input is not a valid frame.
/// Accepts a dict argument, which may be NULL indicating no dictionary.
/// See
/// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#frame-concatenation
static void decode_frame(ostream_t *const out, istream_t *const in,
                         const dictionary_t *const dict);

// Decode data in a compressed block
static void decompress_block(frame_context_t *const ctx, ostream_t *const out,
                             istream_t *const in);

// Decode the literals section of a block; `*literals` receives the decoded
// buffer and the return value is its length
static size_t decode_literals(frame_context_t *const ctx, istream_t *const in,
                              u8 **const literals);

// Decode the sequences part of a block; `*sequences` receives the decoded
// commands and the return value is how many there are
static size_t decode_sequences(frame_context_t *const ctx, istream_t *const in,
                               sequence_command_t **const sequences);

// Execute the decoded sequences on the literals block
static void execute_sequences(frame_context_t *const ctx, ostream_t *const out,
                              const u8 *const literals,
                              const size_t literals_len,
                              const sequence_command_t *const sequences,
                              const size_t num_sequences);

// Copies literals and returns the total literal length that was copied
static u32 copy_literals(const size_t seq, istream_t *litstream,
                         ostream_t *const out);

// Given an offset code from a sequence command (either an actual offset value
// or an index for previous offset), computes the correct offset and updates
// the offset history
static size_t compute_offset(sequence_command_t seq, u64 *const offset_hist);

// Given an offset, match length, and total output, as well as the frame
// context for the dictionary, determines if the dictionary is used and
// executes the copy operation
static void execute_match_copy(frame_context_t *const ctx, size_t offset,
                              size_t match_length, size_t total_output,
                              ostream_t *const out);

/******* END ZSTD HELPER STRUCTS AND PROTOTYPES *******************************/
373
374size_t ZSTD_decompress(void *const dst, const size_t dst_len,
375                       const void *const src, const size_t src_len) {
376    dictionary_t* uninit_dict = create_dictionary();
377    size_t const decomp_size = ZSTD_decompress_with_dict(dst, dst_len, src,
378                                                         src_len, uninit_dict);
379    free_dictionary(uninit_dict);
380    return decomp_size;
381}
382
383size_t ZSTD_decompress_with_dict(void *const dst, const size_t dst_len,
384                                 const void *const src, const size_t src_len,
385                                 dictionary_t* parsed_dict) {
386
387    istream_t in = IO_make_istream(src, src_len);
388    ostream_t out = IO_make_ostream(dst, dst_len);
389
390    // "A content compressed by Zstandard is transformed into a Zstandard frame.
391    // Multiple frames can be appended into a single file or stream. A frame is
392    // totally independent, has a defined beginning and end, and a set of
393    // parameters which tells the decoder how to decompress it."
394
395    /* this decoder assumes decompression of a single frame */
396    decode_frame(&out, &in, parsed_dict);
397
398    return out.ptr - (u8 *)dst;
399}
400
/******* FRAME DECODING ******************************************************/

// Helpers used by decode_frame, in roughly the order they run

// Decode a frame whose magic number identified it as a Zstandard data frame
static void decode_data_frame(ostream_t *const out, istream_t *const in,
                              const dictionary_t *const dict);
// Parse the frame header and set up per-frame state (offsets, dictionary)
static void init_frame_context(frame_context_t *const context,
                               istream_t *const in,
                               const dictionary_t *const dict);
// Release the tables owned by a frame context
static void free_frame_context(frame_context_t *const context);
// Decode the Frame_Header fields into `header`
static void parse_frame_header(frame_header_t *const header,
                               istream_t *const in);
// Seed a frame context from a dictionary's content and tables
static void frame_context_apply_dict(frame_context_t *const ctx,
                                     const dictionary_t *const dict);

// Decode every block of the frame into `out`
static void decompress_data(frame_context_t *const ctx, ostream_t *const out,
                            istream_t *const in);
416
417static void decode_frame(ostream_t *const out, istream_t *const in,
418                         const dictionary_t *const dict) {
419    const u32 magic_number = IO_read_bits(in, 32);
420    // Zstandard frame
421    //
422    // "Magic_Number
423    //
424    // 4 Bytes, little-endian format. Value : 0xFD2FB528"
425    if (magic_number == 0xFD2FB528U) {
426        // ZSTD frame
427        decode_data_frame(out, in, dict);
428
429        return;
430    }
431
432    // not a real frame or a skippable frame
433    ERROR("Tried to decode non-ZSTD frame");
434}
435
436/// Decode a frame that contains compressed data.  Not all frames do as there
437/// are skippable frames.
438/// See
439/// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#general-structure-of-zstandard-frame-format
440static void decode_data_frame(ostream_t *const out, istream_t *const in,
441                              const dictionary_t *const dict) {
442    frame_context_t ctx;
443
444    // Initialize the context that needs to be carried from block to block
445    init_frame_context(&ctx, in, dict);
446
447    if (ctx.header.frame_content_size != 0 &&
448        ctx.header.frame_content_size > out->len) {
449        OUT_SIZE();
450    }
451
452    decompress_data(&ctx, out, in);
453
454    free_frame_context(&ctx);
455}
456
457/// Takes the information provided in the header and dictionary, and initializes
458/// the context for this frame
459static void init_frame_context(frame_context_t *const context,
460                               istream_t *const in,
461                               const dictionary_t *const dict) {
462    // Most fields in context are correct when initialized to 0
463    memset(context, 0, sizeof(frame_context_t));
464
465    // Parse data from the frame header
466    parse_frame_header(&context->header, in);
467
468    // Set up the offset history for the repeat offset commands
469    context->previous_offsets[0] = 1;
470    context->previous_offsets[1] = 4;
471    context->previous_offsets[2] = 8;
472
473    // Apply details from the dict if it exists
474    frame_context_apply_dict(context, dict);
475}
476
477static void free_frame_context(frame_context_t *const context) {
478    HUF_free_dtable(&context->literals_dtable);
479
480    FSE_free_dtable(&context->ll_dtable);
481    FSE_free_dtable(&context->ml_dtable);
482    FSE_free_dtable(&context->of_dtable);
483
484    memset(context, 0, sizeof(frame_context_t));
485}
486
487static void parse_frame_header(frame_header_t *const header,
488                               istream_t *const in) {
489    // "The first header's byte is called the Frame_Header_Descriptor. It tells
490    // which other fields are present. Decoding this byte is enough to tell the
491    // size of Frame_Header.
492    //
493    // Bit number   Field name
494    // 7-6  Frame_Content_Size_flag
495    // 5    Single_Segment_flag
496    // 4    Unused_bit
497    // 3    Reserved_bit
498    // 2    Content_Checksum_flag
499    // 1-0  Dictionary_ID_flag"
500    const u8 descriptor = IO_read_bits(in, 8);
501
502    // decode frame header descriptor into flags
503    const u8 frame_content_size_flag = descriptor >> 6;
504    const u8 single_segment_flag = (descriptor >> 5) & 1;
505    const u8 reserved_bit = (descriptor >> 3) & 1;
506    const u8 content_checksum_flag = (descriptor >> 2) & 1;
507    const u8 dictionary_id_flag = descriptor & 3;
508
509    if (reserved_bit != 0) {
510        CORRUPTION();
511    }
512
513    header->single_segment_flag = single_segment_flag;
514    header->content_checksum_flag = content_checksum_flag;
515
516    // decode window size
517    if (!single_segment_flag) {
518        // "Provides guarantees on maximum back-reference distance that will be
519        // used within compressed data. This information is important for
520        // decoders to allocate enough memory.
521        //
522        // Bit numbers  7-3         2-0
523        // Field name   Exponent    Mantissa"
524        u8 window_descriptor = IO_read_bits(in, 8);
525        u8 exponent = window_descriptor >> 3;
526        u8 mantissa = window_descriptor & 7;
527
528        // Use the algorithm from the specification to compute window size
529        // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#window_descriptor
530        size_t window_base = (size_t)1 << (10 + exponent);
531        size_t window_add = (window_base / 8) * mantissa;
532        header->window_size = window_base + window_add;
533    }
534
535    // decode dictionary id if it exists
536    if (dictionary_id_flag) {
537        // "This is a variable size field, which contains the ID of the
538        // dictionary required to properly decode the frame. Note that this
539        // field is optional. When it's not present, it's up to the caller to
540        // make sure it uses the correct dictionary. Format is little-endian."
541        const int bytes_array[] = {0, 1, 2, 4};
542        const int bytes = bytes_array[dictionary_id_flag];
543
544        header->dictionary_id = IO_read_bits(in, bytes * 8);
545    } else {
546        header->dictionary_id = 0;
547    }
548
549    // decode frame content size if it exists
550    if (single_segment_flag || frame_content_size_flag) {
551        // "This is the original (uncompressed) size. This information is
552        // optional. The Field_Size is provided according to value of
553        // Frame_Content_Size_flag. The Field_Size can be equal to 0 (not
554        // present), 1, 2, 4 or 8 bytes. Format is little-endian."
555        //
556        // if frame_content_size_flag == 0 but single_segment_flag is set, we
557        // still have a 1 byte field
558        const int bytes_array[] = {1, 2, 4, 8};
559        const int bytes = bytes_array[frame_content_size_flag];
560
561        header->frame_content_size = IO_read_bits(in, bytes * 8);
562        if (bytes == 2) {
563            // "When Field_Size is 2, the offset of 256 is added."
564            header->frame_content_size += 256;
565        }
566    } else {
567        header->frame_content_size = 0;
568    }
569
570    if (single_segment_flag) {
571        // "The Window_Descriptor byte is optional. It is absent when
572        // Single_Segment_flag is set. In this case, the maximum back-reference
573        // distance is the content size itself, which can be any value from 1 to
574        // 2^64-1 bytes (16 EB)."
575        header->window_size = header->frame_content_size;
576    }
577}
578
579/// A dictionary acts as initializing values for the frame context before
580/// decompression, so we implement it by applying it's predetermined
581/// tables and content to the context before beginning decompression
582static void frame_context_apply_dict(frame_context_t *const ctx,
583                                     const dictionary_t *const dict) {
584    // If the content pointer is NULL then it must be an empty dict
585    if (!dict || !dict->content)
586        return;
587
588    // If the requested dictionary_id is non-zero, the correct dictionary must
589    // be present
590    if (ctx->header.dictionary_id != 0 &&
591        ctx->header.dictionary_id != dict->dictionary_id) {
592        ERROR("Wrong dictionary provided");
593    }
594
595    // Copy the dict content to the context for references during sequence
596    // execution
597    ctx->dict_content = dict->content;
598    ctx->dict_content_len = dict->content_size;
599
600    // If it's a formatted dict copy the precomputed tables in so they can
601    // be used in the table repeat modes
602    if (dict->dictionary_id != 0) {
603        // Deep copy the entropy tables so they can be freed independently of
604        // the dictionary struct
605        HUF_copy_dtable(&ctx->literals_dtable, &dict->literals_dtable);
606        FSE_copy_dtable(&ctx->ll_dtable, &dict->ll_dtable);
607        FSE_copy_dtable(&ctx->of_dtable, &dict->of_dtable);
608        FSE_copy_dtable(&ctx->ml_dtable, &dict->ml_dtable);
609
610        // Copy the repeated offsets
611        memcpy(ctx->previous_offsets, dict->previous_offsets,
612               sizeof(ctx->previous_offsets));
613    }
614}
615
616/// Decompress the data from a frame block by block
617static void decompress_data(frame_context_t *const ctx, ostream_t *const out,
618                            istream_t *const in) {
619    // "A frame encapsulates one or multiple blocks. Each block can be
620    // compressed or not, and has a guaranteed maximum content size, which
621    // depends on frame parameters. Unlike frames, each block depends on
622    // previous blocks for proper decoding. However, each block can be
623    // decompressed without waiting for its successor, allowing streaming
624    // operations."
625    int last_block = 0;
626    do {
627        // "Last_Block
628        //
629        // The lowest bit signals if this block is the last one. Frame ends
630        // right after this block.
631        //
632        // Block_Type and Block_Size
633        //
634        // The next 2 bits represent the Block_Type, while the remaining 21 bits
635        // represent the Block_Size. Format is little-endian."
636        last_block = IO_read_bits(in, 1);
637        const int block_type = IO_read_bits(in, 2);
638        const size_t block_len = IO_read_bits(in, 21);
639
640        switch (block_type) {
641        case 0: {
642            // "Raw_Block - this is an uncompressed block. Block_Size is the
643            // number of bytes to read and copy."
644            const u8 *const read_ptr = IO_get_read_ptr(in, block_len);
645            u8 *const write_ptr = IO_get_write_ptr(out, block_len);
646
647            // Copy the raw data into the output
648            memcpy(write_ptr, read_ptr, block_len);
649
650            ctx->current_total_output += block_len;
651            break;
652        }
653        case 1: {
654            // "RLE_Block - this is a single byte, repeated N times. In which
655            // case, Block_Size is the size to regenerate, while the
656            // "compressed" block is just 1 byte (the byte to repeat)."
657            const u8 *const read_ptr = IO_get_read_ptr(in, 1);
658            u8 *const write_ptr = IO_get_write_ptr(out, block_len);
659
660            // Copy `block_len` copies of `read_ptr[0]` to the output
661            memset(write_ptr, read_ptr[0], block_len);
662
663            ctx->current_total_output += block_len;
664            break;
665        }
666        case 2: {
667            // "Compressed_Block - this is a Zstandard compressed block,
668            // detailed in another section of this specification. Block_Size is
669            // the compressed size.
670
671            // Create a sub-stream for the block
672            istream_t block_stream = IO_make_sub_istream(in, block_len);
673            decompress_block(ctx, out, &block_stream);
674            break;
675        }
676        case 3:
677            // "Reserved - this is not a block. This value cannot be used with
678            // current version of this specification."
679            CORRUPTION();
680            break;
681        default:
682            IMPOSSIBLE();
683        }
684    } while (!last_block);
685
686    if (ctx->header.content_checksum_flag) {
687        // This program does not support checking the checksum, so skip over it
688        // if it's present
689        IO_advance_input(in, 4);
690    }
691}
692/******* END FRAME DECODING ***************************************************/
693
694/******* BLOCK DECOMPRESSION **************************************************/
695static void decompress_block(frame_context_t *const ctx, ostream_t *const out,
696                             istream_t *const in) {
697    // "A compressed block consists of 2 sections :
698    //
699    // Literals_Section
700    // Sequences_Section"
701
702
703    // Part 1: decode the literals block
704    u8 *literals = NULL;
705    const size_t literals_size = decode_literals(ctx, in, &literals);
706
707    // Part 2: decode the sequences block
708    sequence_command_t *sequences = NULL;
709    const size_t num_sequences =
710        decode_sequences(ctx, in, &sequences);
711
712    // Part 3: combine literals and sequence commands to generate output
713    execute_sequences(ctx, out, literals, literals_size, sequences,
714                      num_sequences);
715    free(literals);
716    free(sequences);
717}
718/******* END BLOCK DECOMPRESSION **********************************************/
719
/******* LITERALS DECODING ****************************************************/
// Decode a Raw or RLE literals section (no entropy coding)
static size_t decode_literals_simple(istream_t *const in, u8 **const literals,
                                     const int block_type,
                                     const int size_format);
// Decode a Huffman-compressed literals section, using or updating the frame's
// literals table
static size_t decode_literals_compressed(frame_context_t *const ctx,
                                         istream_t *const in,
                                         u8 **const literals,
                                         const int block_type,
                                         const int size_format);
// Read a Huffman tree description from `in` and build `dtable` from it
static void decode_huf_table(HUF_dtable *const dtable, istream_t *const in);
// Decode FSE-compressed Huffman weights into `weights`, reporting the symbol
// count through `num_symbs`
static void fse_decode_hufweights(ostream_t *weights, istream_t *const in,
                                    int *const num_symbs);
732
733static size_t decode_literals(frame_context_t *const ctx, istream_t *const in,
734                              u8 **const literals) {
735    // "Literals can be stored uncompressed or compressed using Huffman prefix
736    // codes. When compressed, an optional tree description can be present,
737    // followed by 1 or 4 streams."
738    //
739    // "Literals_Section_Header
740    //
741    // Header is in charge of describing how literals are packed. It's a
742    // byte-aligned variable-size bitfield, ranging from 1 to 5 bytes, using
743    // little-endian convention."
744    //
745    // "Literals_Block_Type
746    //
747    // This field uses 2 lowest bits of first byte, describing 4 different block
748    // types"
749    //
750    // size_format takes between 1 and 2 bits
751    int block_type = IO_read_bits(in, 2);
752    int size_format = IO_read_bits(in, 2);
753
754    if (block_type <= 1) {
755        // Raw or RLE literals block
756        return decode_literals_simple(in, literals, block_type,
757                                      size_format);
758    } else {
759        // Huffman compressed literals
760        return decode_literals_compressed(ctx, in, literals, block_type,
761                                          size_format);
762    }
763}
764
765/// Decodes literals blocks in raw or RLE form
766static size_t decode_literals_simple(istream_t *const in, u8 **const literals,
767                                     const int block_type,
768                                     const int size_format) {
769    size_t size;
770    switch (size_format) {
771    // These cases are in the form ?0
772    // In this case, the ? bit is actually part of the size field
773    case 0:
774    case 2:
775        // "Size_Format uses 1 bit. Regenerated_Size uses 5 bits (0-31)."
776        IO_rewind_bits(in, 1);
777        size = IO_read_bits(in, 5);
778        break;
779    case 1:
780        // "Size_Format uses 2 bits. Regenerated_Size uses 12 bits (0-4095)."
781        size = IO_read_bits(in, 12);
782        break;
783    case 3:
784        // "Size_Format uses 2 bits. Regenerated_Size uses 20 bits (0-1048575)."
785        size = IO_read_bits(in, 20);
786        break;
787    default:
788        // Size format is in range 0-3
789        IMPOSSIBLE();
790    }
791
792    if (size > MAX_LITERALS_SIZE) {
793        CORRUPTION();
794    }
795
796    *literals = malloc(size);
797    if (!*literals) {
798        BAD_ALLOC();
799    }
800
801    switch (block_type) {
802    case 0: {
803        // "Raw_Literals_Block - Literals are stored uncompressed."
804        const u8 *const read_ptr = IO_get_read_ptr(in, size);
805        memcpy(*literals, read_ptr, size);
806        break;
807    }
808    case 1: {
809        // "RLE_Literals_Block - Literals consist of a single byte value repeated N times."
810        const u8 *const read_ptr = IO_get_read_ptr(in, 1);
811        memset(*literals, read_ptr[0], size);
812        break;
813    }
814    default:
815        IMPOSSIBLE();
816    }
817
818    return size;
819}
820
821/// Decodes Huffman compressed literals
822static size_t decode_literals_compressed(frame_context_t *const ctx,
823                                         istream_t *const in,
824                                         u8 **const literals,
825                                         const int block_type,
826                                         const int size_format) {
827    size_t regenerated_size, compressed_size;
828    // Only size_format=0 has 1 stream, so default to 4
829    int num_streams = 4;
830    switch (size_format) {
831    case 0:
832        // "A single stream. Both Compressed_Size and Regenerated_Size use 10
833        // bits (0-1023)."
834        num_streams = 1;
835    // Fall through as it has the same size format
836    case 1:
837        // "4 streams. Both Compressed_Size and Regenerated_Size use 10 bits
838        // (0-1023)."
839        regenerated_size = IO_read_bits(in, 10);
840        compressed_size = IO_read_bits(in, 10);
841        break;
842    case 2:
843        // "4 streams. Both Compressed_Size and Regenerated_Size use 14 bits
844        // (0-16383)."
845        regenerated_size = IO_read_bits(in, 14);
846        compressed_size = IO_read_bits(in, 14);
847        break;
848    case 3:
849        // "4 streams. Both Compressed_Size and Regenerated_Size use 18 bits
850        // (0-262143)."
851        regenerated_size = IO_read_bits(in, 18);
852        compressed_size = IO_read_bits(in, 18);
853        break;
854    default:
855        // Impossible
856        IMPOSSIBLE();
857    }
858    if (regenerated_size > MAX_LITERALS_SIZE ||
859        compressed_size >= regenerated_size) {
860        CORRUPTION();
861    }
862
863    *literals = malloc(regenerated_size);
864    if (!*literals) {
865        BAD_ALLOC();
866    }
867
868    ostream_t lit_stream = IO_make_ostream(*literals, regenerated_size);
869    istream_t huf_stream = IO_make_sub_istream(in, compressed_size);
870
871    if (block_type == 2) {
872        // Decode the provided Huffman table
873        // "This section is only present when Literals_Block_Type type is
874        // Compressed_Literals_Block (2)."
875
876        HUF_free_dtable(&ctx->literals_dtable);
877        decode_huf_table(&ctx->literals_dtable, &huf_stream);
878    } else {
879        // If the previous Huffman table is being repeated, ensure it exists
880        if (!ctx->literals_dtable.symbols) {
881            CORRUPTION();
882        }
883    }
884
885    size_t symbols_decoded;
886    if (num_streams == 1) {
887        symbols_decoded = HUF_decompress_1stream(&ctx->literals_dtable, &lit_stream, &huf_stream);
888    } else {
889        symbols_decoded = HUF_decompress_4stream(&ctx->literals_dtable, &lit_stream, &huf_stream);
890    }
891
892    if (symbols_decoded != regenerated_size) {
893        CORRUPTION();
894    }
895
896    return regenerated_size;
897}
898
899// Decode the Huffman table description
900static void decode_huf_table(HUF_dtable *const dtable, istream_t *const in) {
901    // "All literal values from zero (included) to last present one (excluded)
902    // are represented by Weight with values from 0 to Max_Number_of_Bits."
903
904    // "This is a single byte value (0-255), which describes how to decode the list of weights."
905    const u8 header = IO_read_bits(in, 8);
906
907    u8 weights[HUF_MAX_SYMBS];
908    memset(weights, 0, sizeof(weights));
909
910    int num_symbs;
911
912    if (header >= 128) {
913        // "This is a direct representation, where each Weight is written
914        // directly as a 4 bits field (0-15). The full representation occupies
915        // ((Number_of_Symbols+1)/2) bytes, meaning it uses a last full byte
916        // even if Number_of_Symbols is odd. Number_of_Symbols = headerByte -
917        // 127"
918        num_symbs = header - 127;
919        const size_t bytes = (num_symbs + 1) / 2;
920
921        const u8 *const weight_src = IO_get_read_ptr(in, bytes);
922
923        for (int i = 0; i < num_symbs; i++) {
924            // "They are encoded forward, 2
925            // weights to a byte with the first weight taking the top four bits
926            // and the second taking the bottom four (e.g. the following
927            // operations could be used to read the weights: Weight[0] =
928            // (Byte[0] >> 4), Weight[1] = (Byte[0] & 0xf), etc.)."
929            if (i % 2 == 0) {
930                weights[i] = weight_src[i / 2] >> 4;
931            } else {
932                weights[i] = weight_src[i / 2] & 0xf;
933            }
934        }
935    } else {
936        // The weights are FSE encoded, decode them before we can construct the
937        // table
938        istream_t fse_stream = IO_make_sub_istream(in, header);
939        ostream_t weight_stream = IO_make_ostream(weights, HUF_MAX_SYMBS);
940        fse_decode_hufweights(&weight_stream, &fse_stream, &num_symbs);
941    }
942
943    // Construct the table using the decoded weights
944    HUF_init_dtable_usingweights(dtable, weights, num_symbs);
945}
946
947static void fse_decode_hufweights(ostream_t *weights, istream_t *const in,
948                                    int *const num_symbs) {
949    const int MAX_ACCURACY_LOG = 7;
950
951    FSE_dtable dtable;
952
953    // "An FSE bitstream starts by a header, describing probabilities
954    // distribution. It will create a Decoding Table. For a list of Huffman
955    // weights, maximum accuracy is 7 bits."
956    FSE_decode_header(&dtable, in, MAX_ACCURACY_LOG);
957
958    // Decode the weights
959    *num_symbs = FSE_decompress_interleaved2(&dtable, weights, in);
960
961    FSE_free_dtable(&dtable);
962}
963/******* END LITERALS DECODING ************************************************/
964
965/******* SEQUENCE DECODING ****************************************************/
/// The combination of FSE states needed to decode sequences
/// One decoding table plus its current state for each of the three symbol
/// types that are interleaved in the sequences bitstream.
typedef struct {
    FSE_dtable ll_table; // literal length codes
    FSE_dtable of_table; // offset codes
    FSE_dtable ml_table; // match length codes

    u16 ll_state; // current state in ll_table
    u16 of_state; // current state in of_table
    u16 ml_state; // current state in ml_table
} sequence_states_t;

/// Which of the three sequence parts a table belongs to; also used by
/// decode_seq_table as an index into its per-part parameter arrays
typedef enum {
    seq_literal_length = 0,
    seq_offset = 1,
    seq_match_length = 2,
} seq_part_t;

/// Table compression mode for one sequence part, taken 2 bits at a time from
/// the Symbol_Compression_Modes byte
typedef enum {
    seq_predefined = 0,
    seq_rle = 1,
    seq_fse = 2,
    seq_repeat = 3,
} seq_mode_t;
990
991/// The predefined FSE distribution tables for `seq_predefined` mode
992static const i16 SEQ_LITERAL_LENGTH_DEFAULT_DIST[36] = {
993    4, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1,  1,  2,  2,
994    2, 2, 2, 2, 2, 2, 2, 3, 2, 1, 1, 1, 1, 1, -1, -1, -1, -1};
995static const i16 SEQ_OFFSET_DEFAULT_DIST[29] = {
996    1, 1, 1, 1, 1, 1, 2, 2, 2, 1,  1,  1,  1,  1, 1,
997    1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1};
998static const i16 SEQ_MATCH_LENGTH_DEFAULT_DIST[53] = {
999    1, 4, 3, 2, 2, 2, 2, 2, 2, 1, 1,  1,  1,  1,  1,  1,  1, 1,
1000    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,  1,  1,  1,  1,  1,  1, 1,
1001    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1};
1002
1003/// The sequence decoding baseline and number of additional bits to read/add
1004/// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#the-codes-for-literals-lengths-match-lengths-and-offsets
1005static const u32 SEQ_LITERAL_LENGTH_BASELINES[36] = {
1006    0,  1,  2,   3,   4,   5,    6,    7,    8,    9,     10,    11,
1007    12, 13, 14,  15,  16,  18,   20,   22,   24,   28,    32,    40,
1008    48, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65538};
1009static const u8 SEQ_LITERAL_LENGTH_EXTRA_BITS[36] = {
1010    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  0,  0,  0,  0,  1,  1,
1011    1, 1, 2, 2, 3, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
1012
1013static const u32 SEQ_MATCH_LENGTH_BASELINES[53] = {
1014    3,  4,   5,   6,   7,    8,    9,    10,   11,    12,    13,   14, 15, 16,
1015    17, 18,  19,  20,  21,   22,   23,   24,   25,    26,    27,   28, 29, 30,
1016    31, 32,  33,  34,  35,   37,   39,   41,   43,    47,    51,   59, 67, 83,
1017    99, 131, 259, 515, 1027, 2051, 4099, 8195, 16387, 32771, 65539};
1018static const u8 SEQ_MATCH_LENGTH_EXTRA_BITS[53] = {
1019    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  0,  0,  0,  0,  0,  0, 0,
1020    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  0,  0,  0,  1,  1,  1, 1,
1021    2, 2, 3, 3, 4, 4, 5, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
1022
1023/// Offset decoding is simpler so we just need a maximum code value
1024static const u8 SEQ_MAX_CODES[3] = {35, -1, 52};
1025
// Forward declarations for the sequence-section helpers defined below.

/// Decodes the `num_sequences` sequences of the block into `sequences`
static void decompress_sequences(frame_context_t *const ctx,
                                 istream_t *const in,
                                 sequence_command_t *const sequences,
                                 const size_t num_sequences);
/// Decodes one sequence from the FSE bitstream `src`, which is read backwards
/// from bit position `*offset`
static sequence_command_t decode_sequence(sequence_states_t *const state,
                                          const u8 *const src,
                                          i64 *const offset);
/// Builds, replaces, or reuses the FSE table for one sequence part (`type`)
/// according to its compression `mode`
static void decode_seq_table(FSE_dtable *const table, istream_t *const in,
                               const seq_part_t type, const seq_mode_t mode);
1035
1036static size_t decode_sequences(frame_context_t *const ctx, istream_t *in,
1037                               sequence_command_t **const sequences) {
1038    // "A compressed block is a succession of sequences . A sequence is a
1039    // literal copy command, followed by a match copy command. A literal copy
1040    // command specifies a length. It is the number of bytes to be copied (or
1041    // extracted) from the literal section. A match copy command specifies an
1042    // offset and a length. The offset gives the position to copy from, which
1043    // can be within a previous block."
1044
1045    size_t num_sequences;
1046
1047    // "Number_of_Sequences
1048    //
1049    // This is a variable size field using between 1 and 3 bytes. Let's call its
1050    // first byte byte0."
1051    u8 header = IO_read_bits(in, 8);
1052    if (header == 0) {
1053        // "There are no sequences. The sequence section stops there.
1054        // Regenerated content is defined entirely by literals section."
1055        *sequences = NULL;
1056        return 0;
1057    } else if (header < 128) {
1058        // "Number_of_Sequences = byte0 . Uses 1 byte."
1059        num_sequences = header;
1060    } else if (header < 255) {
1061        // "Number_of_Sequences = ((byte0-128) << 8) + byte1 . Uses 2 bytes."
1062        num_sequences = ((header - 128) << 8) + IO_read_bits(in, 8);
1063    } else {
1064        // "Number_of_Sequences = byte1 + (byte2<<8) + 0x7F00 . Uses 3 bytes."
1065        num_sequences = IO_read_bits(in, 16) + 0x7F00;
1066    }
1067
1068    *sequences = malloc(num_sequences * sizeof(sequence_command_t));
1069    if (!*sequences) {
1070        BAD_ALLOC();
1071    }
1072
1073    decompress_sequences(ctx, in, *sequences, num_sequences);
1074    return num_sequences;
1075}
1076
1077/// Decompress the FSE encoded sequence commands
1078static void decompress_sequences(frame_context_t *const ctx, istream_t *in,
1079                                 sequence_command_t *const sequences,
1080                                 const size_t num_sequences) {
1081    // "The Sequences_Section regroup all symbols required to decode commands.
1082    // There are 3 symbol types : literals lengths, offsets and match lengths.
1083    // They are encoded together, interleaved, in a single bitstream."
1084
1085    // "Symbol compression modes
1086    //
1087    // This is a single byte, defining the compression mode of each symbol
1088    // type."
1089    //
1090    // Bit number : Field name
1091    // 7-6        : Literals_Lengths_Mode
1092    // 5-4        : Offsets_Mode
1093    // 3-2        : Match_Lengths_Mode
1094    // 1-0        : Reserved
1095    u8 compression_modes = IO_read_bits(in, 8);
1096
1097    if ((compression_modes & 3) != 0) {
1098        // Reserved bits set
1099        CORRUPTION();
1100    }
1101
1102    // "Following the header, up to 3 distribution tables can be described. When
1103    // present, they are in this order :
1104    //
1105    // Literals lengths
1106    // Offsets
1107    // Match Lengths"
1108    // Update the tables we have stored in the context
1109    decode_seq_table(&ctx->ll_dtable, in, seq_literal_length,
1110                     (compression_modes >> 6) & 3);
1111
1112    decode_seq_table(&ctx->of_dtable, in, seq_offset,
1113                     (compression_modes >> 4) & 3);
1114
1115    decode_seq_table(&ctx->ml_dtable, in, seq_match_length,
1116                     (compression_modes >> 2) & 3);
1117
1118
1119    sequence_states_t states;
1120
1121    // Initialize the decoding tables
1122    {
1123        states.ll_table = ctx->ll_dtable;
1124        states.of_table = ctx->of_dtable;
1125        states.ml_table = ctx->ml_dtable;
1126    }
1127
1128    const size_t len = IO_istream_len(in);
1129    const u8 *const src = IO_get_read_ptr(in, len);
1130
1131    // "After writing the last bit containing information, the compressor writes
1132    // a single 1-bit and then fills the byte with 0-7 0 bits of padding."
1133    const int padding = 8 - highest_set_bit(src[len - 1]);
1134    // The offset starts at the end because FSE streams are read backwards
1135    i64 bit_offset = len * 8 - padding;
1136
1137    // "The bitstream starts with initial state values, each using the required
1138    // number of bits in their respective accuracy, decoded previously from
1139    // their normalized distribution.
1140    //
1141    // It starts by Literals_Length_State, followed by Offset_State, and finally
1142    // Match_Length_State."
1143    FSE_init_state(&states.ll_table, &states.ll_state, src, &bit_offset);
1144    FSE_init_state(&states.of_table, &states.of_state, src, &bit_offset);
1145    FSE_init_state(&states.ml_table, &states.ml_state, src, &bit_offset);
1146
1147    for (size_t i = 0; i < num_sequences; i++) {
1148        // Decode sequences one by one
1149        sequences[i] = decode_sequence(&states, src, &bit_offset);
1150    }
1151
1152    if (bit_offset != 0) {
1153        CORRUPTION();
1154    }
1155}
1156
1157// Decode a single sequence and update the state
1158static sequence_command_t decode_sequence(sequence_states_t *const states,
1159                                          const u8 *const src,
1160                                          i64 *const offset) {
1161    // "Each symbol is a code in its own context, which specifies Baseline and
1162    // Number_of_Bits to add. Codes are FSE compressed, and interleaved with raw
1163    // additional bits in the same bitstream."
1164
1165    // Decode symbols, but don't update states
1166    const u8 of_code = FSE_peek_symbol(&states->of_table, states->of_state);
1167    const u8 ll_code = FSE_peek_symbol(&states->ll_table, states->ll_state);
1168    const u8 ml_code = FSE_peek_symbol(&states->ml_table, states->ml_state);
1169
1170    // Offset doesn't need a max value as it's not decoded using a table
1171    if (ll_code > SEQ_MAX_CODES[seq_literal_length] ||
1172        ml_code > SEQ_MAX_CODES[seq_match_length]) {
1173        CORRUPTION();
1174    }
1175
1176    // Read the interleaved bits
1177    sequence_command_t seq;
1178    // "Decoding starts by reading the Number_of_Bits required to decode Offset.
1179    // It then does the same for Match_Length, and then for Literals_Length."
1180    seq.offset = ((u32)1 << of_code) + STREAM_read_bits(src, of_code, offset);
1181
1182    seq.match_length =
1183        SEQ_MATCH_LENGTH_BASELINES[ml_code] +
1184        STREAM_read_bits(src, SEQ_MATCH_LENGTH_EXTRA_BITS[ml_code], offset);
1185
1186    seq.literal_length =
1187        SEQ_LITERAL_LENGTH_BASELINES[ll_code] +
1188        STREAM_read_bits(src, SEQ_LITERAL_LENGTH_EXTRA_BITS[ll_code], offset);
1189
1190    // "If it is not the last sequence in the block, the next operation is to
1191    // update states. Using the rules pre-calculated in the decoding tables,
1192    // Literals_Length_State is updated, followed by Match_Length_State, and
1193    // then Offset_State."
1194    // If the stream is complete don't read bits to update state
1195    if (*offset != 0) {
1196        FSE_update_state(&states->ll_table, &states->ll_state, src, offset);
1197        FSE_update_state(&states->ml_table, &states->ml_state, src, offset);
1198        FSE_update_state(&states->of_table, &states->of_state, src, offset);
1199    }
1200
1201    return seq;
1202}
1203
1204/// Given a sequence part and table mode, decode the FSE distribution
1205/// Errors if the mode is `seq_repeat` without a pre-existing table in `table`
1206static void decode_seq_table(FSE_dtable *const table, istream_t *const in,
1207                             const seq_part_t type, const seq_mode_t mode) {
1208    // Constant arrays indexed by seq_part_t
1209    const i16 *const default_distributions[] = {SEQ_LITERAL_LENGTH_DEFAULT_DIST,
1210                                                SEQ_OFFSET_DEFAULT_DIST,
1211                                                SEQ_MATCH_LENGTH_DEFAULT_DIST};
1212    const size_t default_distribution_lengths[] = {36, 29, 53};
1213    const size_t default_distribution_accuracies[] = {6, 5, 6};
1214
1215    const size_t max_accuracies[] = {9, 8, 9};
1216
1217    if (mode != seq_repeat) {
1218        // Free old one before overwriting
1219        FSE_free_dtable(table);
1220    }
1221
1222    switch (mode) {
1223    case seq_predefined: {
1224        // "Predefined_Mode : uses a predefined distribution table."
1225        const i16 *distribution = default_distributions[type];
1226        const size_t symbs = default_distribution_lengths[type];
1227        const size_t accuracy_log = default_distribution_accuracies[type];
1228
1229        FSE_init_dtable(table, distribution, symbs, accuracy_log);
1230        break;
1231    }
1232    case seq_rle: {
1233        // "RLE_Mode : it's a single code, repeated Number_of_Sequences times."
1234        const u8 symb = IO_get_read_ptr(in, 1)[0];
1235        FSE_init_dtable_rle(table, symb);
1236        break;
1237    }
1238    case seq_fse: {
1239        // "FSE_Compressed_Mode : standard FSE compression. A distribution table
1240        // will be present "
1241        FSE_decode_header(table, in, max_accuracies[type]);
1242        break;
1243    }
1244    case seq_repeat:
1245        // "Repeat_Mode : re-use distribution table from previous compressed
1246        // block."
1247        // Nothing to do here, table will be unchanged
1248        if (!table->symbols) {
1249            // This mode is invalid if we don't already have a table
1250            CORRUPTION();
1251        }
1252        break;
1253    default:
1254        // Impossible, as mode is from 0-3
1255        IMPOSSIBLE();
1256        break;
1257    }
1258
1259}
1260/******* END SEQUENCE DECODING ************************************************/
1261
1262/******* SEQUENCE EXECUTION ***************************************************/
1263static void execute_sequences(frame_context_t *const ctx, ostream_t *const out,
1264                              const u8 *const literals,
1265                              const size_t literals_len,
1266                              const sequence_command_t *const sequences,
1267                              const size_t num_sequences) {
1268    istream_t litstream = IO_make_istream(literals, literals_len);
1269
1270    u64 *const offset_hist = ctx->previous_offsets;
1271    size_t total_output = ctx->current_total_output;
1272
1273    for (size_t i = 0; i < num_sequences; i++) {
1274        const sequence_command_t seq = sequences[i];
1275        {
1276            const u32 literals_size = copy_literals(seq.literal_length, &litstream, out);
1277            total_output += literals_size;
1278        }
1279
1280        size_t const offset = compute_offset(seq, offset_hist);
1281
1282        size_t const match_length = seq.match_length;
1283
1284        execute_match_copy(ctx, offset, match_length, total_output, out);
1285
1286        total_output += match_length;
1287    }
1288
1289    // Copy any leftover literals
1290    {
1291        size_t len = IO_istream_len(&litstream);
1292        copy_literals(len, &litstream, out);
1293        total_output += len;
1294    }
1295
1296    ctx->current_total_output = total_output;
1297}
1298
1299static u32 copy_literals(const size_t literal_length, istream_t *litstream,
1300                         ostream_t *const out) {
1301    // If the sequence asks for more literals than are left, the
1302    // sequence must be corrupted
1303    if (literal_length > IO_istream_len(litstream)) {
1304        CORRUPTION();
1305    }
1306
1307    u8 *const write_ptr = IO_get_write_ptr(out, literal_length);
1308    const u8 *const read_ptr =
1309         IO_get_read_ptr(litstream, literal_length);
1310    // Copy literals to output
1311    memcpy(write_ptr, read_ptr, literal_length);
1312
1313    return literal_length;
1314}
1315
1316static size_t compute_offset(sequence_command_t seq, u64 *const offset_hist) {
1317    size_t offset;
1318    // Offsets are special, we need to handle the repeat offsets
1319    if (seq.offset <= 3) {
1320        // "The first 3 values define a repeated offset and we will call
1321        // them Repeated_Offset1, Repeated_Offset2, and Repeated_Offset3.
1322        // They are sorted in recency order, with Repeated_Offset1 meaning
1323        // 'most recent one'".
1324
1325        // Use 0 indexing for the array
1326        u32 idx = seq.offset - 1;
1327        if (seq.literal_length == 0) {
1328            // "There is an exception though, when current sequence's
1329            // literals length is 0. In this case, repeated offsets are
1330            // shifted by one, so Repeated_Offset1 becomes Repeated_Offset2,
1331            // Repeated_Offset2 becomes Repeated_Offset3, and
1332            // Repeated_Offset3 becomes Repeated_Offset1 - 1_byte."
1333            idx++;
1334        }
1335
1336        if (idx == 0) {
1337            offset = offset_hist[0];
1338        } else {
1339            // If idx == 3 then literal length was 0 and the offset was 3,
1340            // as per the exception listed above
1341            offset = idx < 3 ? offset_hist[idx] : offset_hist[0] - 1;
1342
1343            // If idx == 1 we don't need to modify offset_hist[2], since
1344            // we're using the second-most recent code
1345            if (idx > 1) {
1346                offset_hist[2] = offset_hist[1];
1347            }
1348            offset_hist[1] = offset_hist[0];
1349            offset_hist[0] = offset;
1350        }
1351    } else {
1352        // When it's not a repeat offset:
1353        // "if (Offset_Value > 3) offset = Offset_Value - 3;"
1354        offset = seq.offset - 3;
1355
1356        // Shift back history
1357        offset_hist[2] = offset_hist[1];
1358        offset_hist[1] = offset_hist[0];
1359        offset_hist[0] = offset;
1360    }
1361    return offset;
1362}
1363
1364static void execute_match_copy(frame_context_t *const ctx, size_t offset,
1365                              size_t match_length, size_t total_output,
1366                              ostream_t *const out) {
1367    u8 *write_ptr = IO_get_write_ptr(out, match_length);
1368    if (total_output <= ctx->header.window_size) {
1369        // In this case offset might go back into the dictionary
1370        if (offset > total_output + ctx->dict_content_len) {
1371            // The offset goes beyond even the dictionary
1372            CORRUPTION();
1373        }
1374
1375        if (offset > total_output) {
1376            // "The rest of the dictionary is its content. The content act
1377            // as a "past" in front of data to compress or decompress, so it
1378            // can be referenced in sequence commands."
1379            const size_t dict_copy =
1380                MIN(offset - total_output, match_length);
1381            const size_t dict_offset =
1382                ctx->dict_content_len - (offset - total_output);
1383
1384            memcpy(write_ptr, ctx->dict_content + dict_offset, dict_copy);
1385            write_ptr += dict_copy;
1386            match_length -= dict_copy;
1387        }
1388    } else if (offset > ctx->header.window_size) {
1389        CORRUPTION();
1390    }
1391
1392    // We must copy byte by byte because the match length might be larger
1393    // than the offset
1394    // ex: if the output so far was "abc", a command with offset=3 and
1395    // match_length=6 would produce "abcabcabc" as the new output
1396    for (size_t j = 0; j < match_length; j++) {
1397        *write_ptr = *(write_ptr - offset);
1398        write_ptr++;
1399    }
1400}
1401/******* END SEQUENCE EXECUTION ***********************************************/
1402
1403/******* OUTPUT SIZE COUNTING *************************************************/
1404/// Get the decompressed size of an input stream so memory can be allocated in
1405/// advance.
1406/// This implementation assumes `src` points to a single ZSTD-compressed frame
1407size_t ZSTD_get_decompressed_size(const void *src, const size_t src_len) {
1408    istream_t in = IO_make_istream(src, src_len);
1409
1410    // get decompressed size from ZSTD frame header
1411    {
1412        const u32 magic_number = IO_read_bits(&in, 32);
1413
1414        if (magic_number == 0xFD2FB528U) {
1415            // ZSTD frame
1416            frame_header_t header;
1417            parse_frame_header(&header, &in);
1418
1419            if (header.frame_content_size == 0 && !header.single_segment_flag) {
1420                // Content size not provided, we can't tell
1421                return -1;
1422            }
1423
1424            return header.frame_content_size;
1425        } else {
1426            // not a real frame or skippable frame
1427            ERROR("ZSTD frame magic number did not match");
1428        }
1429    }
1430}
1431/******* END OUTPUT SIZE COUNTING *********************************************/
1432
1433/******* DICTIONARY PARSING ***************************************************/
1434#define DICT_SIZE_ERROR() ERROR("Dictionary size cannot be less than 8 bytes")
1435#define NULL_SRC() ERROR("Tried to create dictionary with pointer to null src");
1436
1437dictionary_t* create_dictionary() {
1438    dictionary_t* dict = calloc(1, sizeof(dictionary_t));
1439    if (!dict) {
1440        BAD_ALLOC();
1441    }
1442    return dict;
1443}
1444
1445static void init_dictionary_content(dictionary_t *const dict,
1446                                    istream_t *const in);
1447
1448void parse_dictionary(dictionary_t *const dict, const void *src,
1449                             size_t src_len) {
1450    const u8 *byte_src = (const u8 *)src;
1451    memset(dict, 0, sizeof(dictionary_t));
1452    if (src == NULL) { /* cannot initialize dictionary with null src */
1453        NULL_SRC();
1454    }
1455    if (src_len < 8) {
1456        DICT_SIZE_ERROR();
1457    }
1458
1459    istream_t in = IO_make_istream(byte_src, src_len);
1460
1461    const u32 magic_number = IO_read_bits(&in, 32);
1462    if (magic_number != 0xEC30A437) {
1463        // raw content dict
1464        IO_rewind_bits(&in, 32);
1465        init_dictionary_content(dict, &in);
1466        return;
1467    }
1468
1469    dict->dictionary_id = IO_read_bits(&in, 32);
1470
1471    // "Entropy_Tables : following the same format as the tables in compressed
1472    // blocks. They are stored in following order : Huffman tables for literals,
1473    // FSE table for offsets, FSE table for match lengths, and FSE table for
1474    // literals lengths. It's finally followed by 3 offset values, populating
1475    // recent offsets (instead of using {1,4,8}), stored in order, 4-bytes
1476    // little-endian each, for a total of 12 bytes. Each recent offset must have
1477    // a value < dictionary size."
1478    decode_huf_table(&dict->literals_dtable, &in);
1479    decode_seq_table(&dict->of_dtable, &in, seq_offset, seq_fse);
1480    decode_seq_table(&dict->ml_dtable, &in, seq_match_length, seq_fse);
1481    decode_seq_table(&dict->ll_dtable, &in, seq_literal_length, seq_fse);
1482
1483    // Read in the previous offset history
1484    dict->previous_offsets[0] = IO_read_bits(&in, 32);
1485    dict->previous_offsets[1] = IO_read_bits(&in, 32);
1486    dict->previous_offsets[2] = IO_read_bits(&in, 32);
1487
1488    // Ensure the provided offsets aren't too large
1489    // "Each recent offset must have a value < dictionary size."
1490    for (int i = 0; i < 3; i++) {
1491        if (dict->previous_offsets[i] > src_len) {
1492            ERROR("Dictionary corrupted");
1493        }
1494    }
1495
1496    // "Content : The rest of the dictionary is its content. The content act as
1497    // a "past" in front of data to compress or decompress, so it can be
1498    // referenced in sequence commands."
1499    init_dictionary_content(dict, &in);
1500}
1501
1502static void init_dictionary_content(dictionary_t *const dict,
1503                                    istream_t *const in) {
1504    // Copy in the content
1505    dict->content_size = IO_istream_len(in);
1506    dict->content = malloc(dict->content_size);
1507    if (!dict->content) {
1508        BAD_ALLOC();
1509    }
1510
1511    const u8 *const content = IO_get_read_ptr(in, dict->content_size);
1512
1513    memcpy(dict->content, content, dict->content_size);
1514}
1515
1516/// Free an allocated dictionary
1517void free_dictionary(dictionary_t *const dict) {
1518    HUF_free_dtable(&dict->literals_dtable);
1519    FSE_free_dtable(&dict->ll_dtable);
1520    FSE_free_dtable(&dict->of_dtable);
1521    FSE_free_dtable(&dict->ml_dtable);
1522
1523    free(dict->content);
1524
1525    memset(dict, 0, sizeof(dictionary_t));
1526
1527    free(dict);
1528}
1529/******* END DICTIONARY PARSING ***********************************************/
1530
1531/******* IO STREAM OPERATIONS *************************************************/
/// Error raised by the byte-oriented input helpers (`IO_get_read_ptr`,
/// `IO_advance_input`) when the stream is part-way through a byte, i.e.
/// `bit_offset != 0`.
#define UNALIGNED() ERROR("Attempting to operate on a non-byte aligned stream")
1533/// Reads `num` bits from a bitstream, and updates the internal offset
1534static inline u64 IO_read_bits(istream_t *const in, const int num_bits) {
1535    if (num_bits > 64 || num_bits <= 0) {
1536        ERROR("Attempt to read an invalid number of bits");
1537    }
1538
1539    const size_t bytes = (num_bits + in->bit_offset + 7) / 8;
1540    const size_t full_bytes = (num_bits + in->bit_offset) / 8;
1541    if (bytes > in->len) {
1542        INP_SIZE();
1543    }
1544
1545    const u64 result = read_bits_LE(in->ptr, num_bits, in->bit_offset);
1546
1547    in->bit_offset = (num_bits + in->bit_offset) % 8;
1548    in->ptr += full_bytes;
1549    in->len -= full_bytes;
1550
1551    return result;
1552}
1553
1554/// If a non-zero number of bits have been read from the current byte, advance
1555/// the offset to the next byte
1556static inline void IO_rewind_bits(istream_t *const in, int num_bits) {
1557    if (num_bits < 0) {
1558        ERROR("Attempting to rewind stream by a negative number of bits");
1559    }
1560
1561    // move the offset back by `num_bits` bits
1562    const int new_offset = in->bit_offset - num_bits;
1563    // determine the number of whole bytes we have to rewind, rounding up to an
1564    // integer number (e.g. if `new_offset == -5`, `bytes == 1`)
1565    const i64 bytes = -(new_offset - 7) / 8;
1566
1567    in->ptr -= bytes;
1568    in->len += bytes;
1569    // make sure the resulting `bit_offset` is positive, as mod in C does not
1570    // convert numbers from negative to positive (e.g. -22 % 8 == -6)
1571    in->bit_offset = ((new_offset % 8) + 8) % 8;
1572}
1573
1574/// If the remaining bits in a byte will be unused, advance to the end of the
1575/// byte
1576static inline void IO_align_stream(istream_t *const in) {
1577    if (in->bit_offset != 0) {
1578        if (in->len == 0) {
1579            INP_SIZE();
1580        }
1581        in->ptr++;
1582        in->len--;
1583        in->bit_offset = 0;
1584    }
1585}
1586
1587/// Write the given byte into the output stream
1588static inline void IO_write_byte(ostream_t *const out, u8 symb) {
1589    if (out->len == 0) {
1590        OUT_SIZE();
1591    }
1592
1593    out->ptr[0] = symb;
1594    out->ptr++;
1595    out->len--;
1596}
1597
1598/// Returns the number of bytes left to be read in this stream.  The stream must
1599/// be byte aligned.
1600static inline size_t IO_istream_len(const istream_t *const in) {
1601    return in->len;
1602}
1603
1604/// Returns a pointer where `len` bytes can be read, and advances the internal
1605/// state.  The stream must be byte aligned.
1606static inline const u8 *IO_get_read_ptr(istream_t *const in, size_t len) {
1607    if (len > in->len) {
1608        INP_SIZE();
1609    }
1610    if (in->bit_offset != 0) {
1611        UNALIGNED();
1612    }
1613    const u8 *const ptr = in->ptr;
1614    in->ptr += len;
1615    in->len -= len;
1616
1617    return ptr;
1618}
1619/// Returns a pointer to write `len` bytes to, and advances the internal state
1620static inline u8 *IO_get_write_ptr(ostream_t *const out, size_t len) {
1621    if (len > out->len) {
1622        OUT_SIZE();
1623    }
1624    u8 *const ptr = out->ptr;
1625    out->ptr += len;
1626    out->len -= len;
1627
1628    return ptr;
1629}
1630
1631/// Advance the inner state by `len` bytes
1632static inline void IO_advance_input(istream_t *const in, size_t len) {
1633    if (len > in->len) {
1634         INP_SIZE();
1635    }
1636    if (in->bit_offset != 0) {
1637        UNALIGNED();
1638    }
1639
1640    in->ptr += len;
1641    in->len -= len;
1642}
1643
1644/// Returns an `ostream_t` constructed from the given pointer and length
1645static inline ostream_t IO_make_ostream(u8 *out, size_t len) {
1646    return (ostream_t) { out, len };
1647}
1648
1649/// Returns an `istream_t` constructed from the given pointer and length
1650static inline istream_t IO_make_istream(const u8 *in, size_t len) {
1651    return (istream_t) { in, len, 0 };
1652}
1653
1654/// Returns an `istream_t` with the same base as `in`, and length `len`
1655/// Then, advance `in` to account for the consumed bytes
1656/// `in` must be byte aligned
1657static inline istream_t IO_make_sub_istream(istream_t *const in, size_t len) {
1658    // Consume `len` bytes of the parent stream
1659    const u8 *const ptr = IO_get_read_ptr(in, len);
1660
1661    // Make a substream using the pointer to those `len` bytes
1662    return IO_make_istream(ptr, len);
1663}
1664/******* END IO STREAM OPERATIONS *********************************************/
1665
1666/******* BITSTREAM OPERATIONS *************************************************/
1667/// Read `num` bits (up to 64) from `src + offset`, where `offset` is in bits
1668static inline u64 read_bits_LE(const u8 *src, const int num_bits,
1669                               const size_t offset) {
1670    if (num_bits > 64) {
1671        ERROR("Attempt to read an invalid number of bits");
1672    }
1673
1674    // Skip over bytes that aren't in range
1675    src += offset / 8;
1676    size_t bit_offset = offset % 8;
1677    u64 res = 0;
1678
1679    int shift = 0;
1680    int left = num_bits;
1681    while (left > 0) {
1682        u64 mask = left >= 8 ? 0xff : (((u64)1 << left) - 1);
1683        // Read the next byte, shift it to account for the offset, and then mask
1684        // out the top part if we don't need all the bits
1685        res += (((u64)*src++ >> bit_offset) & mask) << shift;
1686        shift += 8 - bit_offset;
1687        left -= 8 - bit_offset;
1688        bit_offset = 0;
1689    }
1690
1691    return res;
1692}
1693
1694/// Read bits from the end of a HUF or FSE bitstream.  `offset` is in bits, so
1695/// it updates `offset` to `offset - bits`, and then reads `bits` bits from
1696/// `src + offset`.  If the offset becomes negative, the extra bits at the
1697/// bottom are filled in with `0` bits instead of reading from before `src`.
1698static inline u64 STREAM_read_bits(const u8 *const src, const int bits,
1699                                   i64 *const offset) {
1700    *offset = *offset - bits;
1701    size_t actual_off = *offset;
1702    size_t actual_bits = bits;
1703    // Don't actually read bits from before the start of src, so if `*offset <
1704    // 0` fix actual_off and actual_bits to reflect the quantity to read
1705    if (*offset < 0) {
1706        actual_bits += *offset;
1707        actual_off = 0;
1708    }
1709    u64 res = read_bits_LE(src, actual_bits, actual_off);
1710
1711    if (*offset < 0) {
1712        // Fill in the bottom "overflowed" bits with 0's
1713        res = -*offset >= 64 ? 0 : (res << -*offset);
1714    }
1715    return res;
1716}
1717/******* END BITSTREAM OPERATIONS *********************************************/
1718
1719/******* BIT COUNTING OPERATIONS **********************************************/
1720/// Returns `x`, where `2^x` is the largest power of 2 less than or equal to
1721/// `num`, or `-1` if `num == 0`.
1722static inline int highest_set_bit(const u64 num) {
1723    for (int i = 63; i >= 0; i--) {
1724        if (((u64)1 << i) <= num) {
1725            return i;
1726        }
1727    }
1728    return -1;
1729}
1730/******* END BIT COUNTING OPERATIONS ******************************************/
1731
1732/******* HUFFMAN PRIMITIVES ***************************************************/
1733static inline u8 HUF_decode_symbol(const HUF_dtable *const dtable,
1734                                   u16 *const state, const u8 *const src,
1735                                   i64 *const offset) {
1736    // Look up the symbol and number of bits to read
1737    const u8 symb = dtable->symbols[*state];
1738    const u8 bits = dtable->num_bits[*state];
1739    const u16 rest = STREAM_read_bits(src, bits, offset);
1740    // Shift `bits` bits out of the state, keeping the low order bits that
1741    // weren't necessary to determine this symbol.  Then add in the new bits
1742    // read from the stream.
1743    *state = ((*state << bits) + rest) & (((u16)1 << dtable->max_bits) - 1);
1744
1745    return symb;
1746}
1747
1748static inline void HUF_init_state(const HUF_dtable *const dtable,
1749                                  u16 *const state, const u8 *const src,
1750                                  i64 *const offset) {
1751    // Read in a full `dtable->max_bits` bits to initialize the state
1752    const u8 bits = dtable->max_bits;
1753    *state = STREAM_read_bits(src, bits, offset);
1754}
1755
1756static size_t HUF_decompress_1stream(const HUF_dtable *const dtable,
1757                                     ostream_t *const out,
1758                                     istream_t *const in) {
1759    const size_t len = IO_istream_len(in);
1760    if (len == 0) {
1761        INP_SIZE();
1762    }
1763    const u8 *const src = IO_get_read_ptr(in, len);
1764
1765    // "Each bitstream must be read backward, that is starting from the end down
1766    // to the beginning. Therefore it's necessary to know the size of each
1767    // bitstream.
1768    //
1769    // It's also necessary to know exactly which bit is the latest. This is
1770    // detected by a final bit flag : the highest bit of latest byte is a
1771    // final-bit-flag. Consequently, a last byte of 0 is not possible. And the
1772    // final-bit-flag itself is not part of the useful bitstream. Hence, the
1773    // last byte contains between 0 and 7 useful bits."
1774    const int padding = 8 - highest_set_bit(src[len - 1]);
1775
1776    // Offset starts at the end because HUF streams are read backwards
1777    i64 bit_offset = len * 8 - padding;
1778    u16 state;
1779
1780    HUF_init_state(dtable, &state, src, &bit_offset);
1781
1782    size_t symbols_written = 0;
1783    while (bit_offset > -dtable->max_bits) {
1784        // Iterate over the stream, decoding one symbol at a time
1785        IO_write_byte(out, HUF_decode_symbol(dtable, &state, src, &bit_offset));
1786        symbols_written++;
1787    }
1788    // "The process continues up to reading the required number of symbols per
1789    // stream. If a bitstream is not entirely and exactly consumed, hence
1790    // reaching exactly its beginning position with all bits consumed, the
1791    // decoding process is considered faulty."
1792
1793    // When all symbols have been decoded, the final state value shouldn't have
1794    // any data from the stream, so it should have "read" dtable->max_bits from
1795    // before the start of `src`
1796    // Therefore `offset`, the edge to start reading new bits at, should be
1797    // dtable->max_bits before the start of the stream
1798    if (bit_offset != -dtable->max_bits) {
1799        CORRUPTION();
1800    }
1801
1802    return symbols_written;
1803}
1804
1805static size_t HUF_decompress_4stream(const HUF_dtable *const dtable,
1806                                     ostream_t *const out, istream_t *const in) {
1807    // "Compressed size is provided explicitly : in the 4-streams variant,
1808    // bitstreams are preceded by 3 unsigned little-endian 16-bits values. Each
1809    // value represents the compressed size of one stream, in order. The last
1810    // stream size is deducted from total compressed size and from previously
1811    // decoded stream sizes"
1812    const size_t csize1 = IO_read_bits(in, 16);
1813    const size_t csize2 = IO_read_bits(in, 16);
1814    const size_t csize3 = IO_read_bits(in, 16);
1815
1816    istream_t in1 = IO_make_sub_istream(in, csize1);
1817    istream_t in2 = IO_make_sub_istream(in, csize2);
1818    istream_t in3 = IO_make_sub_istream(in, csize3);
1819    istream_t in4 = IO_make_sub_istream(in, IO_istream_len(in));
1820
1821    size_t total_output = 0;
1822    // Decode each stream independently for simplicity
1823    // If we wanted to we could decode all 4 at the same time for speed,
1824    // utilizing more execution units
1825    total_output += HUF_decompress_1stream(dtable, out, &in1);
1826    total_output += HUF_decompress_1stream(dtable, out, &in2);
1827    total_output += HUF_decompress_1stream(dtable, out, &in3);
1828    total_output += HUF_decompress_1stream(dtable, out, &in4);
1829
1830    return total_output;
1831}
1832
1833/// Initializes a Huffman table using canonical Huffman codes
1834/// For more explanation on canonical Huffman codes see
1835/// http://www.cs.uofs.edu/~mccloske/courses/cmps340/huff_canonical_dec2015.html
1836/// Codes within a level are allocated in symbol order (i.e. smaller symbols get
1837/// earlier codes)
1838static void HUF_init_dtable(HUF_dtable *const table, const u8 *const bits,
1839                            const int num_symbs) {
1840    memset(table, 0, sizeof(HUF_dtable));
1841    if (num_symbs > HUF_MAX_SYMBS) {
1842        ERROR("Too many symbols for Huffman");
1843    }
1844
1845    u8 max_bits = 0;
1846    u16 rank_count[HUF_MAX_BITS + 1];
1847    memset(rank_count, 0, sizeof(rank_count));
1848
1849    // Count the number of symbols for each number of bits, and determine the
1850    // depth of the tree
1851    for (int i = 0; i < num_symbs; i++) {
1852        if (bits[i] > HUF_MAX_BITS) {
1853            ERROR("Huffman table depth too large");
1854        }
1855        max_bits = MAX(max_bits, bits[i]);
1856        rank_count[bits[i]]++;
1857    }
1858
1859    const size_t table_size = 1 << max_bits;
1860    table->max_bits = max_bits;
1861    table->symbols = malloc(table_size);
1862    table->num_bits = malloc(table_size);
1863
1864    if (!table->symbols || !table->num_bits) {
1865        free(table->symbols);
1866        free(table->num_bits);
1867        BAD_ALLOC();
1868    }
1869
1870    // "Symbols are sorted by Weight. Within same Weight, symbols keep natural
1871    // order. Symbols with a Weight of zero are removed. Then, starting from
1872    // lowest weight, prefix codes are distributed in order."
1873
1874    u32 rank_idx[HUF_MAX_BITS + 1];
1875    // Initialize the starting codes for each rank (number of bits)
1876    rank_idx[max_bits] = 0;
1877    for (int i = max_bits; i >= 1; i--) {
1878        rank_idx[i - 1] = rank_idx[i] + rank_count[i] * (1 << (max_bits - i));
1879        // The entire range takes the same number of bits so we can memset it
1880        memset(&table->num_bits[rank_idx[i]], i, rank_idx[i - 1] - rank_idx[i]);
1881    }
1882
1883    if (rank_idx[0] != table_size) {
1884        CORRUPTION();
1885    }
1886
1887    // Allocate codes and fill in the table
1888    for (int i = 0; i < num_symbs; i++) {
1889        if (bits[i] != 0) {
1890            // Allocate a code for this symbol and set its range in the table
1891            const u16 code = rank_idx[bits[i]];
1892            // Since the code doesn't care about the bottom `max_bits - bits[i]`
1893            // bits of state, it gets a range that spans all possible values of
1894            // the lower bits
1895            const u16 len = 1 << (max_bits - bits[i]);
1896            memset(&table->symbols[code], i, len);
1897            rank_idx[bits[i]] += len;
1898        }
1899    }
1900}
1901
1902static void HUF_init_dtable_usingweights(HUF_dtable *const table,
1903                                         const u8 *const weights,
1904                                         const int num_symbs) {
1905    // +1 because the last weight is not transmitted in the header
1906    if (num_symbs + 1 > HUF_MAX_SYMBS) {
1907        ERROR("Too many symbols for Huffman");
1908    }
1909
1910    u8 bits[HUF_MAX_SYMBS];
1911
1912    u64 weight_sum = 0;
1913    for (int i = 0; i < num_symbs; i++) {
1914        // Weights are in the same range as bit count
1915        if (weights[i] > HUF_MAX_BITS) {
1916            CORRUPTION();
1917        }
1918        weight_sum += weights[i] > 0 ? (u64)1 << (weights[i] - 1) : 0;
1919    }
1920
1921    // Find the first power of 2 larger than the sum
1922    const int max_bits = highest_set_bit(weight_sum) + 1;
1923    const u64 left_over = ((u64)1 << max_bits) - weight_sum;
1924    // If the left over isn't a power of 2, the weights are invalid
1925    if (left_over & (left_over - 1)) {
1926        CORRUPTION();
1927    }
1928
1929    // left_over is used to find the last weight as it's not transmitted
1930    // by inverting 2^(weight - 1) we can determine the value of last_weight
1931    const int last_weight = highest_set_bit(left_over) + 1;
1932
1933    for (int i = 0; i < num_symbs; i++) {
1934        // "Number_of_Bits = Number_of_Bits ? Max_Number_of_Bits + 1 - Weight : 0"
1935        bits[i] = weights[i] > 0 ? (max_bits + 1 - weights[i]) : 0;
1936    }
1937    bits[num_symbs] =
1938        max_bits + 1 - last_weight; // Last weight is always non-zero
1939
1940    HUF_init_dtable(table, bits, num_symbs + 1);
1941}
1942
1943static void HUF_free_dtable(HUF_dtable *const dtable) {
1944    free(dtable->symbols);
1945    free(dtable->num_bits);
1946    memset(dtable, 0, sizeof(HUF_dtable));
1947}
1948
1949static void HUF_copy_dtable(HUF_dtable *const dst,
1950                            const HUF_dtable *const src) {
1951    if (src->max_bits == 0) {
1952        memset(dst, 0, sizeof(HUF_dtable));
1953        return;
1954    }
1955
1956    const size_t size = (size_t)1 << src->max_bits;
1957    dst->max_bits = src->max_bits;
1958
1959    dst->symbols = malloc(size);
1960    dst->num_bits = malloc(size);
1961    if (!dst->symbols || !dst->num_bits) {
1962        BAD_ALLOC();
1963    }
1964
1965    memcpy(dst->symbols, src->symbols, size);
1966    memcpy(dst->num_bits, src->num_bits, size);
1967}
1968/******* END HUFFMAN PRIMITIVES ***********************************************/
1969
1970/******* FSE PRIMITIVES *******************************************************/
1971/// For more description of FSE see
1972/// https://github.com/Cyan4973/FiniteStateEntropy/
1973
1974/// Allow a symbol to be decoded without updating state
1975static inline u8 FSE_peek_symbol(const FSE_dtable *const dtable,
1976                                 const u16 state) {
1977    return dtable->symbols[state];
1978}
1979
1980/// Consumes bits from the input and uses the current state to determine the
1981/// next state
1982static inline void FSE_update_state(const FSE_dtable *const dtable,
1983                                    u16 *const state, const u8 *const src,
1984                                    i64 *const offset) {
1985    const u8 bits = dtable->num_bits[*state];
1986    const u16 rest = STREAM_read_bits(src, bits, offset);
1987    *state = dtable->new_state_base[*state] + rest;
1988}
1989
1990/// Decodes a single FSE symbol and updates the offset
1991static inline u8 FSE_decode_symbol(const FSE_dtable *const dtable,
1992                                   u16 *const state, const u8 *const src,
1993                                   i64 *const offset) {
1994    const u8 symb = FSE_peek_symbol(dtable, *state);
1995    FSE_update_state(dtable, state, src, offset);
1996    return symb;
1997}
1998
1999static inline void FSE_init_state(const FSE_dtable *const dtable,
2000                                  u16 *const state, const u8 *const src,
2001                                  i64 *const offset) {
2002    // Read in a full `accuracy_log` bits to initialize the state
2003    const u8 bits = dtable->accuracy_log;
2004    *state = STREAM_read_bits(src, bits, offset);
2005}
2006
2007static size_t FSE_decompress_interleaved2(const FSE_dtable *const dtable,
2008                                          ostream_t *const out,
2009                                          istream_t *const in) {
2010    const size_t len = IO_istream_len(in);
2011    if (len == 0) {
2012        INP_SIZE();
2013    }
2014    const u8 *const src = IO_get_read_ptr(in, len);
2015
2016    // "Each bitstream must be read backward, that is starting from the end down
2017    // to the beginning. Therefore it's necessary to know the size of each
2018    // bitstream.
2019    //
2020    // It's also necessary to know exactly which bit is the latest. This is
2021    // detected by a final bit flag : the highest bit of latest byte is a
2022    // final-bit-flag. Consequently, a last byte of 0 is not possible. And the
2023    // final-bit-flag itself is not part of the useful bitstream. Hence, the
2024    // last byte contains between 0 and 7 useful bits."
2025    const int padding = 8 - highest_set_bit(src[len - 1]);
2026    i64 offset = len * 8 - padding;
2027
2028    u16 state1, state2;
2029    // "The first state (State1) encodes the even indexed symbols, and the
2030    // second (State2) encodes the odd indexes. State1 is initialized first, and
2031    // then State2, and they take turns decoding a single symbol and updating
2032    // their state."
2033    FSE_init_state(dtable, &state1, src, &offset);
2034    FSE_init_state(dtable, &state2, src, &offset);
2035
2036    // Decode until we overflow the stream
2037    // Since we decode in reverse order, overflowing the stream is offset going
2038    // negative
2039    size_t symbols_written = 0;
2040    while (1) {
2041        // "The number of symbols to decode is determined by tracking bitStream
2042        // overflow condition: If updating state after decoding a symbol would
2043        // require more bits than remain in the stream, it is assumed the extra
2044        // bits are 0. Then, the symbols for each of the final states are
2045        // decoded and the process is complete."
2046        IO_write_byte(out, FSE_decode_symbol(dtable, &state1, src, &offset));
2047        symbols_written++;
2048        if (offset < 0) {
2049            // There's still a symbol to decode in state2
2050            IO_write_byte(out, FSE_peek_symbol(dtable, state2));
2051            symbols_written++;
2052            break;
2053        }
2054
2055        IO_write_byte(out, FSE_decode_symbol(dtable, &state2, src, &offset));
2056        symbols_written++;
2057        if (offset < 0) {
2058            // There's still a symbol to decode in state1
2059            IO_write_byte(out, FSE_peek_symbol(dtable, state1));
2060            symbols_written++;
2061            break;
2062        }
2063    }
2064
2065    return symbols_written;
2066}
2067
2068static void FSE_init_dtable(FSE_dtable *const dtable,
2069                            const i16 *const norm_freqs, const int num_symbs,
2070                            const int accuracy_log) {
2071    if (accuracy_log > FSE_MAX_ACCURACY_LOG) {
2072        ERROR("FSE accuracy too large");
2073    }
2074    if (num_symbs > FSE_MAX_SYMBS) {
2075        ERROR("Too many symbols for FSE");
2076    }
2077
2078    dtable->accuracy_log = accuracy_log;
2079
2080    const size_t size = (size_t)1 << accuracy_log;
2081    dtable->symbols = malloc(size * sizeof(u8));
2082    dtable->num_bits = malloc(size * sizeof(u8));
2083    dtable->new_state_base = malloc(size * sizeof(u16));
2084
2085    if (!dtable->symbols || !dtable->num_bits || !dtable->new_state_base) {
2086        BAD_ALLOC();
2087    }
2088
2089    // Used to determine how many bits need to be read for each state,
2090    // and where the destination range should start
2091    // Needs to be u16 because max value is 2 * max number of symbols,
2092    // which can be larger than a byte can store
2093    u16 state_desc[FSE_MAX_SYMBS];
2094
2095    // "Symbols are scanned in their natural order for "less than 1"
2096    // probabilities. Symbols with this probability are being attributed a
2097    // single cell, starting from the end of the table. These symbols define a
2098    // full state reset, reading Accuracy_Log bits."
2099    int high_threshold = size;
2100    for (int s = 0; s < num_symbs; s++) {
2101        // Scan for low probability symbols to put at the top
2102        if (norm_freqs[s] == -1) {
2103            dtable->symbols[--high_threshold] = s;
2104            state_desc[s] = 1;
2105        }
2106    }
2107
2108    // "All remaining symbols are sorted in their natural order. Starting from
2109    // symbol 0 and table position 0, each symbol gets attributed as many cells
2110    // as its probability. Cell allocation is spreaded, not linear."
2111    // Place the rest in the table
2112    const u16 step = (size >> 1) + (size >> 3) + 3;
2113    const u16 mask = size - 1;
2114    u16 pos = 0;
2115    for (int s = 0; s < num_symbs; s++) {
2116        if (norm_freqs[s] <= 0) {
2117            continue;
2118        }
2119
2120        state_desc[s] = norm_freqs[s];
2121
2122        for (int i = 0; i < norm_freqs[s]; i++) {
2123            // Give `norm_freqs[s]` states to symbol s
2124            dtable->symbols[pos] = s;
2125            // "A position is skipped if already occupied, typically by a "less
2126            // than 1" probability symbol."
2127            do {
2128                pos = (pos + step) & mask;
2129            } while (pos >=
2130                     high_threshold);
2131            // Note: no other collision checking is necessary as `step` is
2132            // coprime to `size`, so the cycle will visit each position exactly
2133            // once
2134        }
2135    }
2136    if (pos != 0) {
2137        CORRUPTION();
2138    }
2139
2140    // Now we can fill baseline and num bits
2141    for (size_t i = 0; i < size; i++) {
2142        u8 symbol = dtable->symbols[i];
2143        u16 next_state_desc = state_desc[symbol]++;
2144        // Fills in the table appropriately, next_state_desc increases by symbol
2145        // over time, decreasing number of bits
2146        dtable->num_bits[i] = (u8)(accuracy_log - highest_set_bit(next_state_desc));
2147        // Baseline increases until the bit threshold is passed, at which point
2148        // it resets to 0
2149        dtable->new_state_base[i] =
2150            ((u16)next_state_desc << dtable->num_bits[i]) - size;
2151    }
2152}
2153
2154/// Decode an FSE header as defined in the Zstandard format specification and
2155/// use the decoded frequencies to initialize a decoding table.
2156static void FSE_decode_header(FSE_dtable *const dtable, istream_t *const in,
2157                                const int max_accuracy_log) {
2158    // "An FSE distribution table describes the probabilities of all symbols
2159    // from 0 to the last present one (included) on a normalized scale of 1 <<
2160    // Accuracy_Log .
2161    //
2162    // It's a bitstream which is read forward, in little-endian fashion. It's
2163    // not necessary to know its exact size, since it will be discovered and
2164    // reported by the decoding process.
2165    if (max_accuracy_log > FSE_MAX_ACCURACY_LOG) {
2166        ERROR("FSE accuracy too large");
2167    }
2168
2169    // The bitstream starts by reporting on which scale it operates.
2170    // Accuracy_Log = low4bits + 5. Note that maximum Accuracy_Log for literal
2171    // and match lengths is 9, and for offsets is 8. Higher values are
2172    // considered errors."
2173    const int accuracy_log = 5 + IO_read_bits(in, 4);
2174    if (accuracy_log > max_accuracy_log) {
2175        ERROR("FSE accuracy too large");
2176    }
2177
2178    // "Then follows each symbol value, from 0 to last present one. The number
2179    // of bits used by each field is variable. It depends on :
2180    //
2181    // Remaining probabilities + 1 : example : Presuming an Accuracy_Log of 8,
2182    // and presuming 100 probabilities points have already been distributed, the
2183    // decoder may read any value from 0 to 255 - 100 + 1 == 156 (inclusive).
2184    // Therefore, it must read log2sup(156) == 8 bits.
2185    //
2186    // Value decoded : small values use 1 less bit : example : Presuming values
2187    // from 0 to 156 (inclusive) are possible, 255-156 = 99 values are remaining
2188    // in an 8-bits field. They are used this way : first 99 values (hence from
2189    // 0 to 98) use only 7 bits, values from 99 to 156 use 8 bits. "
2190
2191    i32 remaining = 1 << accuracy_log;
2192    i16 frequencies[FSE_MAX_SYMBS];
2193
2194    int symb = 0;
2195    while (remaining > 0 && symb < FSE_MAX_SYMBS) {
2196        // Log of the number of possible values we could read
2197        int bits = highest_set_bit(remaining + 1) + 1;
2198
2199        u16 val = IO_read_bits(in, bits);
2200
2201        // Try to mask out the lower bits to see if it qualifies for the "small
2202        // value" threshold
2203        const u16 lower_mask = ((u16)1 << (bits - 1)) - 1;
2204        const u16 threshold = ((u16)1 << bits) - 1 - (remaining + 1);
2205
2206        if ((val & lower_mask) < threshold) {
2207            IO_rewind_bits(in, 1);
2208            val = val & lower_mask;
2209        } else if (val > lower_mask) {
2210            val = val - threshold;
2211        }
2212
2213        // "Probability is obtained from Value decoded by following formula :
2214        // Proba = value - 1"
2215        const i16 proba = (i16)val - 1;
2216
2217        // "It means value 0 becomes negative probability -1. -1 is a special
2218        // probability, which means "less than 1". Its effect on distribution
2219        // table is described in next paragraph. For the purpose of calculating
2220        // cumulated distribution, it counts as one."
2221        remaining -= proba < 0 ? -proba : proba;
2222
2223        frequencies[symb] = proba;
2224        symb++;
2225
2226        // "When a symbol has a probability of zero, it is followed by a 2-bits
2227        // repeat flag. This repeat flag tells how many probabilities of zeroes
2228        // follow the current one. It provides a number ranging from 0 to 3. If
2229        // it is a 3, another 2-bits repeat flag follows, and so on."
2230        if (proba == 0) {
2231            // Read the next two bits to see how many more 0s
2232            int repeat = IO_read_bits(in, 2);
2233
2234            while (1) {
2235                for (int i = 0; i < repeat && symb < FSE_MAX_SYMBS; i++) {
2236                    frequencies[symb++] = 0;
2237                }
2238                if (repeat == 3) {
2239                    repeat = IO_read_bits(in, 2);
2240                } else {
2241                    break;
2242                }
2243            }
2244        }
2245    }
2246    IO_align_stream(in);
2247
2248    // "When last symbol reaches cumulated total of 1 << Accuracy_Log, decoding
2249    // is complete. If the last symbol makes cumulated total go above 1 <<
2250    // Accuracy_Log, distribution is considered corrupted."
2251    if (remaining != 0 || symb >= FSE_MAX_SYMBS) {
2252        CORRUPTION();
2253    }
2254
2255    // Initialize the decoding table using the determined weights
2256    FSE_init_dtable(dtable, frequencies, symb, accuracy_log);
2257}
2258
2259static void FSE_init_dtable_rle(FSE_dtable *const dtable, const u8 symb) {
2260    dtable->symbols = malloc(sizeof(u8));
2261    dtable->num_bits = malloc(sizeof(u8));
2262    dtable->new_state_base = malloc(sizeof(u16));
2263
2264    if (!dtable->symbols || !dtable->num_bits || !dtable->new_state_base) {
2265        BAD_ALLOC();
2266    }
2267
2268    // This setup will always have a state of 0, always return symbol `symb`,
2269    // and never consume any bits
2270    dtable->symbols[0] = symb;
2271    dtable->num_bits[0] = 0;
2272    dtable->new_state_base[0] = 0;
2273    dtable->accuracy_log = 0;
2274}
2275
2276static void FSE_free_dtable(FSE_dtable *const dtable) {
2277    free(dtable->symbols);
2278    free(dtable->num_bits);
2279    free(dtable->new_state_base);
2280    memset(dtable, 0, sizeof(FSE_dtable));
2281}
2282
2283static void FSE_copy_dtable(FSE_dtable *const dst, const FSE_dtable *const src) {
2284    if (src->accuracy_log == 0) {
2285        memset(dst, 0, sizeof(FSE_dtable));
2286        return;
2287    }
2288
2289    size_t size = (size_t)1 << src->accuracy_log;
2290    dst->accuracy_log = src->accuracy_log;
2291
2292    dst->symbols = malloc(size);
2293    dst->num_bits = malloc(size);
2294    dst->new_state_base = malloc(size * sizeof(u16));
2295    if (!dst->symbols || !dst->num_bits || !dst->new_state_base) {
2296        BAD_ALLOC();
2297    }
2298
2299    memcpy(dst->symbols, src->symbols, size);
2300    memcpy(dst->num_bits, src->num_bits, size);
2301    memcpy(dst->new_state_base, src->new_state_base, size * sizeof(u16));
2302}
2303/******* END FSE PRIMITIVES ***************************************************/
2304