1/* infback.c -- inflate using a call-back interface
2 * Copyright (C) 1995-2022 Mark Adler
3 * For conditions of distribution and use, see copyright notice in zlib.h
4 */
5
6/*
7   This code is largely copied from inflate.c.  Normally either infback.o or
8   inflate.o would be linked into an application--not both.  The interface
9   with inffast.c is retained so that optimized assembler-coded versions of
10   inflate_fast() can be used with either inflate.c or infback.c.
11 */
12
13#include "zutil.h"
14#include "inftrees.h"
15#include "inflate.h"
16#include "inffast.h"
17
18/*
19   strm provides memory allocation functions in zalloc and zfree, or
20   Z_NULL to use the library memory allocation functions.
21
22   windowBits is in the range 8..15, and window is a user-supplied
23   window and output buffer that is 2**windowBits bytes.
24 */
25int ZEXPORT inflateBackInit_(z_streamp strm, int windowBits,
26                             unsigned char FAR *window, const char *version,
27                             int stream_size) {
28    struct inflate_state FAR *state;
29
30    if (version == Z_NULL || version[0] != ZLIB_VERSION[0] ||
31        stream_size != (int)(sizeof(z_stream)))
32        return Z_VERSION_ERROR;
33    if (strm == Z_NULL || window == Z_NULL ||
34        windowBits < 8 || windowBits > 15)
35        return Z_STREAM_ERROR;
36    strm->msg = Z_NULL;                 /* in case we return an error */
37    if (strm->zalloc == (alloc_func)0) {
38#ifdef Z_SOLO
39        return Z_STREAM_ERROR;
40#else
41        strm->zalloc = zcalloc;
42        strm->opaque = (voidpf)0;
43#endif
44    }
45    if (strm->zfree == (free_func)0)
46#ifdef Z_SOLO
47        return Z_STREAM_ERROR;
48#else
49    strm->zfree = zcfree;
50#endif
51    state = (struct inflate_state FAR *)ZALLOC(strm, 1,
52                                               sizeof(struct inflate_state));
53    if (state == Z_NULL) return Z_MEM_ERROR;
54    Tracev((stderr, "inflate: allocated\n"));
55    strm->state = (struct internal_state FAR *)state;
56    state->dmax = 32768U;
57    state->wbits = (uInt)windowBits;
58    state->wsize = 1U << windowBits;
59    state->window = window;
60    state->wnext = 0;
61    state->whave = 0;
62    state->sane = 1;
63    return Z_OK;
64}
65
66/*
67   Return state with length and distance decoding tables and index sizes set to
68   fixed code decoding.  Normally this returns fixed tables from inffixed.h.
69   If BUILDFIXED is defined, then instead this routine builds the tables the
70   first time it's called, and returns those tables the first time and
71   thereafter.  This reduces the size of the code by about 2K bytes, in
72   exchange for a little execution time.  However, BUILDFIXED should not be
73   used for threaded applications, since the rewriting of the tables and virgin
74   may not be thread-safe.
75 */
76local void fixedtables(struct inflate_state FAR *state) {
77#ifdef BUILDFIXED
78    static int virgin = 1;
79    static code *lenfix, *distfix;
80    static code fixed[544];
81
82    /* build fixed huffman tables if first call (may not be thread safe) */
83    if (virgin) {
84        unsigned sym, bits;
85        static code *next;
86
87        /* literal/length table */
88        sym = 0;
89        while (sym < 144) state->lens[sym++] = 8;
90        while (sym < 256) state->lens[sym++] = 9;
91        while (sym < 280) state->lens[sym++] = 7;
92        while (sym < 288) state->lens[sym++] = 8;
93        next = fixed;
94        lenfix = next;
95        bits = 9;
96        inflate_table(LENS, state->lens, 288, &(next), &(bits), state->work);
97
98        /* distance table */
99        sym = 0;
100        while (sym < 32) state->lens[sym++] = 5;
101        distfix = next;
102        bits = 5;
103        inflate_table(DISTS, state->lens, 32, &(next), &(bits), state->work);
104
105        /* do this just once */
106        virgin = 0;
107    }
108#else /* !BUILDFIXED */
109#   include "inffixed.h"
110#endif /* BUILDFIXED */
111    state->lencode = lenfix;
112    state->lenbits = 9;
113    state->distcode = distfix;
114    state->distbits = 5;
115}
116
/* Macros for inflateBack(): */

/* These expand inside inflateBack() and refer directly to its locals (strm,
   state, put, left, next, have, hold, bits, ret), to the in()/out() callbacks
   and their descriptors, and to the inf_leave label, so they only make sense
   inside that function. */

/* Load returned state from inflate_fast(): copy the stream and state fields
   back into inflateBack()'s local "register" variables */
#define LOAD() \
    do { \
        put = strm->next_out; \
        left = strm->avail_out; \
        next = strm->next_in; \
        have = strm->avail_in; \
        hold = state->hold; \
        bits = state->bits; \
    } while (0)

/* Set state from registers for inflate_fast() (the inverse of LOAD()) */
#define RESTORE() \
    do { \
        strm->next_out = put; \
        strm->avail_out = left; \
        strm->next_in = next; \
        strm->avail_in = have; \
        state->hold = hold; \
        state->bits = bits; \
    } while (0)

/* Clear the input bit accumulator */
#define INITBITS() \
    do { \
        hold = 0; \
        bits = 0; \
    } while (0)

/* Assure that some input is available.  If input is requested, but denied,
   then return a Z_BUF_ERROR from inflateBack().  next is set to Z_NULL on
   failure so the caller can tell an in() failure from an out() failure. */
#define PULL() \
    do { \
        if (have == 0) { \
            have = in(in_desc, &next); \
            if (have == 0) { \
                next = Z_NULL; \
                ret = Z_BUF_ERROR; \
                goto inf_leave; \
            } \
        } \
    } while (0)

/* Get a byte of input into the bit accumulator, or return from inflateBack()
   with an error if there is no input available.  The new byte is placed
   above the bits already held, so bits are consumed least-significant
   first. */
#define PULLBYTE() \
    do { \
        PULL(); \
        have--; \
        hold += (unsigned long)(*next++) << bits; \
        bits += 8; \
    } while (0)

/* Assure that there are at least n bits in the bit accumulator.  If there is
   not enough available input to do that, then return from inflateBack() with
   an error. */
#define NEEDBITS(n) \
    do { \
        while (bits < (unsigned)(n)) \
            PULLBYTE(); \
    } while (0)

/* Return the low n bits of the bit accumulator (n < 16) */
#define BITS(n) \
    ((unsigned)hold & ((1U << (n)) - 1))

/* Remove n bits from the bit accumulator.  n is evaluated twice, so it must
   not have side effects. */
#define DROPBITS(n) \
    do { \
        hold >>= (n); \
        bits -= (unsigned)(n); \
    } while (0)

/* Remove zero to seven bits as needed to go to a byte boundary */
#define BYTEBITS() \
    do { \
        hold >>= bits & 7; \
        bits -= bits & 7; \
    } while (0)

/* Assure that some output space is available, by writing out the window
   if it's full.  If the write fails, return from inflateBack() with a
   Z_BUF_ERROR.  After a flush, output restarts at the start of the window
   and the whole window (whave = wsize) counts as history for matches. */
#define ROOM() \
    do { \
        if (left == 0) { \
            put = state->window; \
            left = state->wsize; \
            state->whave = left; \
            if (out(out_desc, put, left)) { \
                ret = Z_BUF_ERROR; \
                goto inf_leave; \
            } \
        } \
    } while (0)
214
215/*
216   strm provides the memory allocation functions and window buffer on input,
217   and provides information on the unused input on return.  For Z_DATA_ERROR
218   returns, strm will also provide an error message.
219
220   in() and out() are the call-back input and output functions.  When
221   inflateBack() needs more input, it calls in().  When inflateBack() has
222   filled the window with output, or when it completes with data in the
223   window, it calls out() to write out the data.  The application must not
224   change the provided input until in() is called again or inflateBack()
225   returns.  The application must not change the window/output buffer until
226   inflateBack() returns.
227
228   in() and out() are called with a descriptor parameter provided in the
229   inflateBack() call.  This parameter can be a structure that provides the
230   information required to do the read or write, as well as accumulated
231   information on the input and output such as totals and check values.
232
233   in() should return zero on failure.  out() should return non-zero on
   failure.  If either in() or out() fails, then inflateBack() returns a
   Z_BUF_ERROR.  strm->next_in can be checked for Z_NULL to see whether it
   was in() or out() that caused the error.  Otherwise, inflateBack()
   returns Z_STREAM_END on success or Z_DATA_ERROR for a deflate format
   error.  (Z_MEM_ERROR, for failure to allocate the state, comes from
   inflateBackInit_(); inflateBack() itself allocates no memory.)
239   inflateBack() can also return Z_STREAM_ERROR if the input parameters
240   are not correct, i.e. strm is Z_NULL or the state was not initialized.
241 */
int ZEXPORT inflateBack(z_streamp strm, in_func in, void FAR *in_desc,
                        out_func out, void FAR *out_desc) {
    /* Decode one complete deflate stream (no zlib or gzip wrapper is parsed
       here), reading through in() and writing through out(), with
       state->window serving as both the output buffer and match history. */
    struct inflate_state FAR *state;
    z_const unsigned char FAR *next;    /* next input */
    unsigned char FAR *put;     /* next output */
    unsigned have, left;        /* available input and output */
    unsigned long hold;         /* bit buffer */
    unsigned bits;              /* bits in bit buffer */
    unsigned copy;              /* number of stored or match bytes to copy */
    unsigned char FAR *from;    /* where to copy match bytes from */
    code here;                  /* current decoding table entry */
    code last;                  /* parent table entry */
    unsigned len;               /* length to copy for repeats, bits to drop */
    int ret;                    /* return code */
    static const unsigned short order[19] = /* permutation of code lengths */
        {16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15};

    /* Check that the strm exists and that the state was initialized */
    if (strm == Z_NULL || strm->state == Z_NULL)
        return Z_STREAM_ERROR;
    state = (struct inflate_state FAR *)strm->state;

    /* Reset the state: every call starts a new stream with an empty bit
       accumulator, no window history (whave = 0), and the whole window
       available for output.  A Z_NULL next_in means no initial input, so
       the first read goes through in(). */
    strm->msg = Z_NULL;
    state->mode = TYPE;
    state->last = 0;
    state->whave = 0;
    next = strm->next_in;
    have = next != Z_NULL ? strm->avail_in : 0;
    hold = 0;
    bits = 0;
    put = state->window;
    left = state->wsize;

    /* Inflate until end of block marked as last */
    for (;;)
        switch (state->mode) {
        case TYPE:
            /* determine and dispatch block type */
            if (state->last) {
                /* previous block was the final one: drop bits up to the
                   next byte boundary and finish */
                BYTEBITS();
                state->mode = DONE;
                break;
            }
            NEEDBITS(3);
            state->last = BITS(1);
            DROPBITS(1);
            switch (BITS(2)) {
            case 0:                             /* stored block */
                Tracev((stderr, "inflate:     stored block%s\n",
                        state->last ? " (last)" : ""));
                state->mode = STORED;
                break;
            case 1:                             /* fixed block */
                fixedtables(state);
                Tracev((stderr, "inflate:     fixed codes block%s\n",
                        state->last ? " (last)" : ""));
                state->mode = LEN;              /* decode codes */
                break;
            case 2:                             /* dynamic block */
                Tracev((stderr, "inflate:     dynamic codes block%s\n",
                        state->last ? " (last)" : ""));
                state->mode = TABLE;
                break;
            case 3:
#ifdef SMALL
                strm->msg = "error";
#else
                strm->msg = (char *)"invalid block type";
#endif
                state->mode = BAD;
            }
            DROPBITS(2);
            break;

        case STORED:
            /* get and verify stored block length */
            BYTEBITS();                         /* go to byte boundary */
            NEEDBITS(32);
            /* the low 16 bits are the length; the next 16 bits must be its
               one's complement */
            if ((hold & 0xffff) != ((hold >> 16) ^ 0xffff)) {
#ifdef SMALL
                strm->msg = "error";
#else
                strm->msg = (char *)"invalid stored block lengths";
#endif
                state->mode = BAD;
                break;
            }
            state->length = (unsigned)hold & 0xffff;
            Tracev((stderr, "inflate:       stored length %u\n",
                    state->length));
            INITBITS();

            /* copy stored block from input to output, as much per pass as
               both the current input buffer and the window allow */
            while (state->length != 0) {
                copy = state->length;
                PULL();
                ROOM();
                if (copy > have) copy = have;
                if (copy > left) copy = left;
                zmemcpy(put, next, copy);
                have -= copy;
                next += copy;
                left -= copy;
                put += copy;
                state->length -= copy;
            }
            Tracev((stderr, "inflate:       stored end\n"));
            state->mode = TYPE;
            break;

        case TABLE:
            /* get dynamic table entries descriptor */
            NEEDBITS(14);
            state->nlen = BITS(5) + 257;
            DROPBITS(5);
            state->ndist = BITS(5) + 1;
            DROPBITS(5);
            state->ncode = BITS(4) + 4;
            DROPBITS(4);
#ifndef PKZIP_BUG_WORKAROUND
            if (state->nlen > 286 || state->ndist > 30) {
#ifdef SMALL
                strm->msg = "error";
#else
                strm->msg = (char *)"too many length or distance symbols";
#endif
                state->mode = BAD;
                break;
            }
#endif
            Tracev((stderr, "inflate:       table sizes ok\n"));

            /* get code length code lengths (not a typo) */
            state->have = 0;
            while (state->have < state->ncode) {
                NEEDBITS(3);
                state->lens[order[state->have++]] = (unsigned short)BITS(3);
                DROPBITS(3);
            }
            while (state->have < 19)
                state->lens[order[state->have++]] = 0;
            /* ret temporarily holds inflate_table()'s status here; any
               nonzero result is reported as a data error (mode BAD) */
            state->next = state->codes;
            state->lencode = (code const FAR *)(state->next);
            state->lenbits = 7;
            ret = inflate_table(CODES, state->lens, 19, &(state->next),
                                &(state->lenbits), state->work);
            if (ret) {
#ifdef SMALL
                strm->msg = "error";
#else
                strm->msg = (char *)"invalid code lengths set";
#endif
                state->mode = BAD;
                break;
            }
            Tracev((stderr, "inflate:       code lengths ok\n"));

            /* get length and distance code code lengths */
            state->have = 0;
            while (state->have < state->nlen + state->ndist) {
                for (;;) {
                    here = state->lencode[BITS(state->lenbits)];
                    if ((unsigned)(here.bits) <= bits) break;
                    PULLBYTE();
                }
                if (here.val < 16) {
                    /* 0..15: a literal code length */
                    DROPBITS(here.bits);
                    state->lens[state->have++] = here.val;
                }
                else {
                    if (here.val == 16) {
                        /* 16: repeat the previous length 3..6 times */
                        NEEDBITS(here.bits + 2);
                        DROPBITS(here.bits);
                        if (state->have == 0) {
#ifdef SMALL
                            strm->msg = "error";
#else
                            strm->msg = (char *)"invalid bit length repeat";
#endif
                            state->mode = BAD;
                            break;
                        }
                        len = (unsigned)(state->lens[state->have - 1]);
                        copy = 3 + BITS(2);
                        DROPBITS(2);
                    }
                    else if (here.val == 17) {
                        /* 17: repeat a zero length 3..10 times */
                        NEEDBITS(here.bits + 3);
                        DROPBITS(here.bits);
                        len = 0;
                        copy = 3 + BITS(3);
                        DROPBITS(3);
                    }
                    else {
                        /* 18: repeat a zero length 11..138 times */
                        NEEDBITS(here.bits + 7);
                        DROPBITS(here.bits);
                        len = 0;
                        copy = 11 + BITS(7);
                        DROPBITS(7);
                    }
                    if (state->have + copy > state->nlen + state->ndist) {
#ifdef SMALL
                        strm->msg = "error";
#else
                        strm->msg = (char *)"invalid bit length repeat";
#endif
                        state->mode = BAD;
                        break;
                    }
                    while (copy--)
                        state->lens[state->have++] = (unsigned short)len;
                }
            }

            /* handle error breaks in while */
            if (state->mode == BAD) break;

            /* check for end-of-block code (better have one) */
            if (state->lens[256] == 0) {
#ifdef SMALL
                strm->msg = "error";
#else
                strm->msg = (char *)"invalid code -- missing end-of-block";
#endif
                state->mode = BAD;
                break;
            }

            /* build code tables -- note: do not change the lenbits or distbits
               values here (9 and 6) without reading the comments in inftrees.h
               concerning the ENOUGH constants, which depend on those values */
            state->next = state->codes;
            state->lencode = (code const FAR *)(state->next);
            state->lenbits = 9;
            ret = inflate_table(LENS, state->lens, state->nlen, &(state->next),
                                &(state->lenbits), state->work);
            if (ret) {
#ifdef SMALL
                strm->msg = "error";
#else
                strm->msg = (char *)"invalid literal/lengths set";
#endif
                state->mode = BAD;
                break;
            }
            state->distcode = (code const FAR *)(state->next);
            state->distbits = 6;
            ret = inflate_table(DISTS, state->lens + state->nlen, state->ndist,
                            &(state->next), &(state->distbits), state->work);
            if (ret) {
#ifdef SMALL
                strm->msg = "error";
#else
                strm->msg = (char *)"invalid distances set";
#endif
                state->mode = BAD;
                break;
            }
            Tracev((stderr, "inflate:       codes ok\n"));
            state->mode = LEN;
                /* fallthrough */

        case LEN:
#ifndef SLOW
            /* use inflate_fast() if we have enough input and output;
               RESTORE() hands the local variables to it through strm/state
               and LOAD() picks them back up.  Until the window has been
               flushed once, only wsize - left bytes of history exist. */
            if (have >= 6 && left >= 258) {
                RESTORE();
                if (state->whave < state->wsize)
                    state->whave = state->wsize - left;
                inflate_fast(strm, state->wsize);
                LOAD();
                break;
            }
#endif

            /* get a literal, length, or end-of-block code */
            for (;;) {
                here = state->lencode[BITS(state->lenbits)];
                if ((unsigned)(here.bits) <= bits) break;
                PULLBYTE();
            }
            /* a nonzero op with a zero high nibble links to a second-level
               table: last.val is its offset, last.op its index width */
            if (here.op && (here.op & 0xf0) == 0) {
                last = here;
                for (;;) {
                    here = state->lencode[last.val +
                            (BITS(last.bits + last.op) >> last.bits)];
                    if ((unsigned)(last.bits + here.bits) <= bits) break;
                    PULLBYTE();
                }
                DROPBITS(last.bits);
            }
            DROPBITS(here.bits);
            state->length = (unsigned)here.val;

            /* process literal */
            if (here.op == 0) {
                Tracevv((stderr, here.val >= 0x20 && here.val < 0x7f ?
                        "inflate:         literal '%c'\n" :
                        "inflate:         literal 0x%02x\n", here.val));
                ROOM();
                *put++ = (unsigned char)(state->length);
                left--;
                state->mode = LEN;
                break;
            }

            /* process end of block */
            if (here.op & 32) {
                Tracevv((stderr, "inflate:         end of block\n"));
                state->mode = TYPE;
                break;
            }

            /* invalid code */
            if (here.op & 64) {
#ifdef SMALL
                strm->msg = "error";
#else
                strm->msg = (char *)"invalid literal/length code";
#endif
                state->mode = BAD;
                break;
            }

            /* length code -- get extra bits, if any */
            state->extra = (unsigned)(here.op) & 15;
            if (state->extra != 0) {
                NEEDBITS(state->extra);
                state->length += BITS(state->extra);
                DROPBITS(state->extra);
            }
            Tracevv((stderr, "inflate:         length %u\n", state->length));

            /* get distance code (second-level table links handled as for
               the length code above) */
            for (;;) {
                here = state->distcode[BITS(state->distbits)];
                if ((unsigned)(here.bits) <= bits) break;
                PULLBYTE();
            }
            if ((here.op & 0xf0) == 0) {
                last = here;
                for (;;) {
                    here = state->distcode[last.val +
                            (BITS(last.bits + last.op) >> last.bits)];
                    if ((unsigned)(last.bits + here.bits) <= bits) break;
                    PULLBYTE();
                }
                DROPBITS(last.bits);
            }
            DROPBITS(here.bits);
            if (here.op & 64) {
#ifdef SMALL
                strm->msg = "error";
#else
                strm->msg = (char *)"invalid distance code";
#endif
                state->mode = BAD;
                break;
            }
            state->offset = (unsigned)here.val;

            /* get distance extra bits, if any */
            state->extra = (unsigned)(here.op) & 15;
            if (state->extra != 0) {
                NEEDBITS(state->extra);
                state->offset += BITS(state->extra);
                DROPBITS(state->extra);
            }
            /* the match may reach back over the whole window once it has
               been flushed, otherwise only over the wsize - left bytes
               written so far */
            if (state->offset > state->wsize - (state->whave < state->wsize ?
                                                left : 0)) {
#ifdef SMALL
                strm->msg = "error";
#else
                strm->msg = (char *)"invalid distance too far back";
#endif
                state->mode = BAD;
                break;
            }
            Tracevv((stderr, "inflate:         distance %u\n", state->offset));

            /* copy match from window to output */
            do {
                ROOM();
                copy = state->wsize - state->offset;
                if (copy < left) {
                    /* source precedes the window start: it wraps to the
                       older bytes at the end of the window, so copy only
                       up to the window end this pass */
                    from = put + copy;
                    copy = left - copy;
                }
                else {
                    from = put - state->offset;
                    copy = left;
                }
                if (copy > state->length) copy = state->length;
                state->length -= copy;
                left -= copy;
                /* byte at a time so overlapping matches (offset smaller
                   than length) replicate correctly */
                do {
                    *put++ = *from++;
                } while (--copy);
            } while (state->length != 0);
            break;

        case DONE:
            /* inflate stream terminated properly */
            ret = Z_STREAM_END;
            goto inf_leave;

        case BAD:
            ret = Z_DATA_ERROR;
            goto inf_leave;

        default:
            /* can't happen, but makes compilers happy */
            ret = Z_STREAM_ERROR;
            goto inf_leave;
        }

    /* Write leftover output and return unused input.  A failing out() here
       only turns a successful Z_STREAM_END into Z_BUF_ERROR; other return
       codes are kept.  next is Z_NULL if in() failed (see PULL()). */
  inf_leave:
    if (left < state->wsize) {
        if (out(out_desc, state->window, state->wsize - left) &&
            ret == Z_STREAM_END)
            ret = Z_BUF_ERROR;
    }
    strm->next_in = next;
    strm->avail_in = have;
    return ret;
}
670
671int ZEXPORT inflateBackEnd(z_streamp strm) {
672    if (strm == Z_NULL || strm->state == Z_NULL || strm->zfree == (free_func)0)
673        return Z_STREAM_ERROR;
674    ZFREE(strm, strm->state, sizeof(struct inflate_state));
675    strm->state = Z_NULL;
676    Tracev((stderr, "inflate: end\n"));
677    return Z_OK;
678}
679