infback.c revision 134354
1/* infback.c -- inflate using a call-back interface
2 * Copyright (C) 1995-2003 Mark Adler
3 * For conditions of distribution and use, see copyright notice in zlib.h
4 */
5
6/*
7   This code is largely copied from inflate.c.  Normally either infback.o or
8   inflate.o would be linked into an application--not both.  The interface
9   with inffast.c is retained so that optimized assembler-coded versions of
10   inflate_fast() can be used with either inflate.c or infback.c.
11 */
12
13#include "zutil.h"
14#include "inftrees.h"
15#include "inflate.h"
16#include "inffast.h"
17
18/* function prototypes */
19local void fixedtables OF((struct inflate_state FAR *state));
20
21/*
22   strm provides memory allocation functions in zalloc and zfree, or
23   Z_NULL to use the library memory allocation functions.
24
25   windowBits is in the range 8..15, and window is a user-supplied
26   window and output buffer that is 2**windowBits bytes.
27 */
28int ZEXPORT inflateBackInit_(strm, windowBits, window, version, stream_size)
29z_stream FAR *strm;
30int windowBits;
31unsigned char FAR *window;
32const char *version;
33int stream_size;
34{
35    struct inflate_state FAR *state;
36
37    if (version == Z_NULL || version[0] != ZLIB_VERSION[0] ||
38        stream_size != (int)(sizeof(z_stream)))
39        return Z_VERSION_ERROR;
40    if (strm == Z_NULL || window == Z_NULL ||
41        windowBits < 8 || windowBits > 15)
42        return Z_STREAM_ERROR;
43    strm->msg = Z_NULL;                 /* in case we return an error */
44    if (strm->zalloc == (alloc_func)0) {
45        strm->zalloc = zcalloc;
46        strm->opaque = (voidpf)0;
47    }
48    if (strm->zfree == (free_func)0) strm->zfree = zcfree;
49    state = (struct inflate_state FAR *)ZALLOC(strm, 1,
50                                               sizeof(struct inflate_state));
51    if (state == Z_NULL) return Z_MEM_ERROR;
52    Tracev((stderr, "inflate: allocated\n"));
53    strm->state = (voidpf)state;
54    state->wbits = windowBits;
55    state->wsize = 1U << windowBits;
56    state->window = window;
57    state->write = 0;
58    state->whave = 0;
59    return Z_OK;
60}
61
62/*
63   Return state with length and distance decoding tables and index sizes set to
64   fixed code decoding.  Normally this returns fixed tables from inffixed.h.
65   If BUILDFIXED is defined, then instead this routine builds the tables the
66   first time it's called, and returns those tables the first time and
67   thereafter.  This reduces the size of the code by about 2K bytes, in
68   exchange for a little execution time.  However, BUILDFIXED should not be
69   used for threaded applications, since the rewriting of the tables and virgin
70   may not be thread-safe.
71 */
72local void fixedtables(state)
73struct inflate_state FAR *state;
74{
75#ifdef BUILDFIXED
76    static int virgin = 1;
77    static code *lenfix, *distfix;
78    static code fixed[544];
79
80    /* build fixed huffman tables if first call (may not be thread safe) */
81    if (virgin) {
82        unsigned sym, bits;
83        static code *next;
84
85        /* literal/length table */
86        sym = 0;
87        while (sym < 144) state->lens[sym++] = 8;
88        while (sym < 256) state->lens[sym++] = 9;
89        while (sym < 280) state->lens[sym++] = 7;
90        while (sym < 288) state->lens[sym++] = 8;
91        next = fixed;
92        lenfix = next;
93        bits = 9;
94        inflate_table(LENS, state->lens, 288, &(next), &(bits), state->work);
95
96        /* distance table */
97        sym = 0;
98        while (sym < 32) state->lens[sym++] = 5;
99        distfix = next;
100        bits = 5;
101        inflate_table(DISTS, state->lens, 32, &(next), &(bits), state->work);
102
103        /* do this just once */
104        virgin = 0;
105    }
106#else /* !BUILDFIXED */
107#   include "inffixed.h"
108#endif /* BUILDFIXED */
109    state->lencode = lenfix;
110    state->lenbits = 9;
111    state->distcode = distfix;
112    state->distbits = 5;
113}
114
/* Macros for inflateBack(): */

/* Reload the local working registers from strm/state after inflate_fast()
   has advanced them. */
#define LOAD() \
    do { \
        hold = state->hold; \
        bits = state->bits; \
        next = strm->next_in; \
        have = strm->avail_in; \
        put = strm->next_out; \
        left = strm->avail_out; \
    } while (0)

/* Publish the local working registers into strm/state so that
   inflate_fast() can pick them up. */
#define RESTORE() \
    do { \
        state->hold = hold; \
        state->bits = bits; \
        strm->next_in = next; \
        strm->avail_in = have; \
        strm->next_out = put; \
        strm->avail_out = left; \
    } while (0)

/* Empty the input bit accumulator */
#define INITBITS() \
    do { \
        bits = 0; \
        hold = 0; \
    } while (0)

/* Make sure at least one input byte is available, asking in() for more when
   the current buffer is exhausted.  If in() supplies nothing, next is set to
   Z_NULL (so the caller can tell in() failed) and inflateBack() leaves with
   Z_BUF_ERROR. */
#define PULL() \
    do { \
        if (have == 0 && (have = in(in_desc, &next)) == 0) { \
            next = Z_NULL; \
            ret = Z_BUF_ERROR; \
            goto inf_leave; \
        } \
    } while (0)

/* Shift the next input byte into the top of the bit accumulator, leaving
   inflateBack() via PULL() if no input can be obtained. */
#define PULLBYTE() \
    do { \
        PULL(); \
        hold += (unsigned long)(*next++) << bits; \
        have--; \
        bits += 8; \
    } while (0)

/* Top up the bit accumulator until it holds at least n bits, leaving
   inflateBack() via PULL() if the input runs dry. */
#define NEEDBITS(n) \
    do { \
        while (bits < (unsigned)(n)) { \
            PULLBYTE(); \
        } \
    } while (0)

/* Value of the lowest n bits of the accumulator (n < 16) */
#define BITS(n) \
    ((unsigned)hold & ((1U << (n)) - 1))

/* Discard the lowest n bits of the accumulator */
#define DROPBITS(n) \
    do { \
        hold >>= (n); \
        bits -= (unsigned)(n); \
    } while (0)

/* Discard the 0..7 bits that remain before the next byte boundary */
#define BYTEBITS() \
    DROPBITS(bits & 7)

/* Make sure there is room for at least one output byte.  When the window is
   full it is handed to out() in its entirety and output restarts at its
   beginning; a non-zero return from out() makes inflateBack() leave with
   Z_BUF_ERROR. */
#define ROOM() \
    do { \
        if (left == 0) { \
            put = state->window; \
            left = state->wsize; \
            state->whave = left; \
            if (out(out_desc, put, left)) { \
                ret = Z_BUF_ERROR; \
                goto inf_leave; \
            } \
        } \
    } while (0)
212
213/*
214   strm provides the memory allocation functions and window buffer on input,
215   and provides information on the unused input on return.  For Z_DATA_ERROR
216   returns, strm will also provide an error message.
217
218   in() and out() are the call-back input and output functions.  When
219   inflateBack() needs more input, it calls in().  When inflateBack() has
220   filled the window with output, or when it completes with data in the
221   window, it calls out() to write out the data.  The application must not
222   change the provided input until in() is called again or inflateBack()
223   returns.  The application must not change the window/output buffer until
224   inflateBack() returns.
225
226   in() and out() are called with a descriptor parameter provided in the
227   inflateBack() call.  This parameter can be a structure that provides the
228   information required to do the read or write, as well as accumulated
229   information on the input and output such as totals and check values.
230
231   in() should return zero on failure.  out() should return non-zero on
232   failure.  If either in() or out() fails, than inflateBack() returns a
233   Z_BUF_ERROR.  strm->next_in can be checked for Z_NULL to see whether it
234   was in() or out() that caused in the error.  Otherwise,  inflateBack()
235   returns Z_STREAM_END on success, Z_DATA_ERROR for an deflate format
236   error, or Z_MEM_ERROR if it could not allocate memory for the state.
237   inflateBack() can also return Z_STREAM_ERROR if the input parameters
238   are not correct, i.e. strm is Z_NULL or the state was not initialized.
239 */
240int ZEXPORT inflateBack(strm, in, in_desc, out, out_desc)
241z_stream FAR *strm;
242in_func in;
243void FAR *in_desc;
244out_func out;
245void FAR *out_desc;
246{
247    struct inflate_state FAR *state;
248    unsigned char FAR *next;    /* next input */
249    unsigned char FAR *put;     /* next output */
250    unsigned have, left;        /* available input and output */
251    unsigned long hold;         /* bit buffer */
252    unsigned bits;              /* bits in bit buffer */
253    unsigned copy;              /* number of stored or match bytes to copy */
254    unsigned char FAR *from;    /* where to copy match bytes from */
255    code this;                  /* current decoding table entry */
256    code last;                  /* parent table entry */
257    unsigned len;               /* length to copy for repeats, bits to drop */
258    int ret;                    /* return code */
259    static const unsigned short order[19] = /* permutation of code lengths */
260        {16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15};
261
262    /* Check that the strm exists and that the state was initialized */
263    if (strm == Z_NULL || strm->state == Z_NULL)
264        return Z_STREAM_ERROR;
265    state = (struct inflate_state FAR *)strm->state;
266
267    /* Reset the state */
268    strm->msg = Z_NULL;
269    state->mode = TYPE;
270    state->last = 0;
271    state->whave = 0;
272    next = strm->next_in;
273    have = next != Z_NULL ? strm->avail_in : 0;
274    hold = 0;
275    bits = 0;
276    put = state->window;
277    left = state->wsize;
278
279    /* Inflate until end of block marked as last */
280    for (;;)
281        switch (state->mode) {
282        case TYPE:
283            /* determine and dispatch block type */
284            if (state->last) {
285                BYTEBITS();
286                state->mode = DONE;
287                break;
288            }
289            NEEDBITS(3);
290            state->last = BITS(1);
291            DROPBITS(1);
292            switch (BITS(2)) {
293            case 0:                             /* stored block */
294                Tracev((stderr, "inflate:     stored block%s\n",
295                        state->last ? " (last)" : ""));
296                state->mode = STORED;
297                break;
298            case 1:                             /* fixed block */
299                fixedtables(state);
300                Tracev((stderr, "inflate:     fixed codes block%s\n",
301                        state->last ? " (last)" : ""));
302                state->mode = LEN;              /* decode codes */
303                break;
304            case 2:                             /* dynamic block */
305                Tracev((stderr, "inflate:     dynamic codes block%s\n",
306                        state->last ? " (last)" : ""));
307                state->mode = TABLE;
308                break;
309            case 3:
310                strm->msg = (char *)"invalid block type";
311                state->mode = BAD;
312            }
313            DROPBITS(2);
314            break;
315
316        case STORED:
317            /* get and verify stored block length */
318            BYTEBITS();                         /* go to byte boundary */
319            NEEDBITS(32);
320            if ((hold & 0xffff) != ((hold >> 16) ^ 0xffff)) {
321                strm->msg = (char *)"invalid stored block lengths";
322                state->mode = BAD;
323                break;
324            }
325            state->length = (unsigned)hold & 0xffff;
326            Tracev((stderr, "inflate:       stored length %u\n",
327                    state->length));
328            INITBITS();
329
330            /* copy stored block from input to output */
331            while (state->length != 0) {
332                copy = state->length;
333                PULL();
334                ROOM();
335                if (copy > have) copy = have;
336                if (copy > left) copy = left;
337                zmemcpy(put, next, copy);
338                have -= copy;
339                next += copy;
340                left -= copy;
341                put += copy;
342                state->length -= copy;
343            }
344            Tracev((stderr, "inflate:       stored end\n"));
345            state->mode = TYPE;
346            break;
347
348        case TABLE:
349            /* get dynamic table entries descriptor */
350            NEEDBITS(14);
351            state->nlen = BITS(5) + 257;
352            DROPBITS(5);
353            state->ndist = BITS(5) + 1;
354            DROPBITS(5);
355            state->ncode = BITS(4) + 4;
356            DROPBITS(4);
357#ifndef PKZIP_BUG_WORKAROUND
358            if (state->nlen > 286 || state->ndist > 30) {
359                strm->msg = (char *)"too many length or distance symbols";
360                state->mode = BAD;
361                break;
362            }
363#endif
364            Tracev((stderr, "inflate:       table sizes ok\n"));
365
366            /* get code length code lengths (not a typo) */
367            state->have = 0;
368            while (state->have < state->ncode) {
369                NEEDBITS(3);
370                state->lens[order[state->have++]] = (unsigned short)BITS(3);
371                DROPBITS(3);
372            }
373            while (state->have < 19)
374                state->lens[order[state->have++]] = 0;
375            state->next = state->codes;
376            state->lencode = (code const FAR *)(state->next);
377            state->lenbits = 7;
378            ret = inflate_table(CODES, state->lens, 19, &(state->next),
379                                &(state->lenbits), state->work);
380            if (ret) {
381                strm->msg = (char *)"invalid code lengths set";
382                state->mode = BAD;
383                break;
384            }
385            Tracev((stderr, "inflate:       code lengths ok\n"));
386
387            /* get length and distance code code lengths */
388            state->have = 0;
389            while (state->have < state->nlen + state->ndist) {
390                for (;;) {
391                    this = state->lencode[BITS(state->lenbits)];
392                    if ((unsigned)(this.bits) <= bits) break;
393                    PULLBYTE();
394                }
395                if (this.val < 16) {
396                    NEEDBITS(this.bits);
397                    DROPBITS(this.bits);
398                    state->lens[state->have++] = this.val;
399                }
400                else {
401                    if (this.val == 16) {
402                        NEEDBITS(this.bits + 2);
403                        DROPBITS(this.bits);
404                        if (state->have == 0) {
405                            strm->msg = (char *)"invalid bit length repeat";
406                            state->mode = BAD;
407                            break;
408                        }
409                        len = (unsigned)(state->lens[state->have - 1]);
410                        copy = 3 + BITS(2);
411                        DROPBITS(2);
412                    }
413                    else if (this.val == 17) {
414                        NEEDBITS(this.bits + 3);
415                        DROPBITS(this.bits);
416                        len = 0;
417                        copy = 3 + BITS(3);
418                        DROPBITS(3);
419                    }
420                    else {
421                        NEEDBITS(this.bits + 7);
422                        DROPBITS(this.bits);
423                        len = 0;
424                        copy = 11 + BITS(7);
425                        DROPBITS(7);
426                    }
427                    if (state->have + copy > state->nlen + state->ndist) {
428                        strm->msg = (char *)"invalid bit length repeat";
429                        state->mode = BAD;
430                        break;
431                    }
432                    while (copy--)
433                        state->lens[state->have++] = (unsigned short)len;
434                }
435            }
436
437            if (state->mode == BAD)
438                break;
439
440            /* build code tables */
441            state->next = state->codes;
442            state->lencode = (code const FAR *)(state->next);
443            state->lenbits = 9;
444            ret = inflate_table(LENS, state->lens, state->nlen, &(state->next),
445                                &(state->lenbits), state->work);
446            if (ret) {
447                strm->msg = (char *)"invalid literal/lengths set";
448                state->mode = BAD;
449                break;
450            }
451            state->distcode = (code const FAR *)(state->next);
452            state->distbits = 6;
453            ret = inflate_table(DISTS, state->lens + state->nlen, state->ndist,
454                            &(state->next), &(state->distbits), state->work);
455            if (ret) {
456                strm->msg = (char *)"invalid distances set";
457                state->mode = BAD;
458                break;
459            }
460            Tracev((stderr, "inflate:       codes ok\n"));
461            state->mode = LEN;
462
463        case LEN:
464            /* use inflate_fast() if we have enough input and output */
465            if (have >= 6 && left >= 258) {
466                RESTORE();
467                if (state->whave < state->wsize)
468                    state->whave = state->wsize - left;
469                inflate_fast(strm, state->wsize);
470                LOAD();
471                break;
472            }
473
474            /* get a literal, length, or end-of-block code */
475            for (;;) {
476                this = state->lencode[BITS(state->lenbits)];
477                if ((unsigned)(this.bits) <= bits) break;
478                PULLBYTE();
479            }
480            if (this.op && (this.op & 0xf0) == 0) {
481                last = this;
482                for (;;) {
483                    this = state->lencode[last.val +
484                            (BITS(last.bits + last.op) >> last.bits)];
485                    if ((unsigned)(last.bits + this.bits) <= bits) break;
486                    PULLBYTE();
487                }
488                DROPBITS(last.bits);
489            }
490            DROPBITS(this.bits);
491            state->length = (unsigned)this.val;
492
493            /* process literal */
494            if (this.op == 0) {
495                Tracevv((stderr, this.val >= 0x20 && this.val < 0x7f ?
496                        "inflate:         literal '%c'\n" :
497                        "inflate:         literal 0x%02x\n", this.val));
498                ROOM();
499                *put++ = (unsigned char)(state->length);
500                left--;
501                state->mode = LEN;
502                break;
503            }
504
505            /* process end of block */
506            if (this.op & 32) {
507                Tracevv((stderr, "inflate:         end of block\n"));
508                state->mode = TYPE;
509                break;
510            }
511
512            /* invalid code */
513            if (this.op & 64) {
514                strm->msg = (char *)"invalid literal/length code";
515                state->mode = BAD;
516                break;
517            }
518
519            /* length code -- get extra bits, if any */
520            state->extra = (unsigned)(this.op) & 15;
521            if (state->extra != 0) {
522                NEEDBITS(state->extra);
523                state->length += BITS(state->extra);
524                DROPBITS(state->extra);
525            }
526            Tracevv((stderr, "inflate:         length %u\n", state->length));
527
528            /* get distance code */
529            for (;;) {
530                this = state->distcode[BITS(state->distbits)];
531                if ((unsigned)(this.bits) <= bits) break;
532                PULLBYTE();
533            }
534            if ((this.op & 0xf0) == 0) {
535                last = this;
536                for (;;) {
537                    this = state->distcode[last.val +
538                            (BITS(last.bits + last.op) >> last.bits)];
539                    if ((unsigned)(last.bits + this.bits) <= bits) break;
540                    PULLBYTE();
541                }
542                DROPBITS(last.bits);
543            }
544            DROPBITS(this.bits);
545            if (this.op & 64) {
546                strm->msg = (char *)"invalid distance code";
547                state->mode = BAD;
548                break;
549            }
550            state->offset = (unsigned)this.val;
551
552            /* get distance extra bits, if any */
553            state->extra = (unsigned)(this.op) & 15;
554            if (state->extra != 0) {
555                NEEDBITS(state->extra);
556                state->offset += BITS(state->extra);
557                DROPBITS(state->extra);
558            }
559            if (state->offset > state->wsize - (state->whave < state->wsize ?
560                                                left : 0)) {
561                strm->msg = (char *)"invalid distance too far back";
562                state->mode = BAD;
563                break;
564            }
565            Tracevv((stderr, "inflate:         distance %u\n", state->offset));
566
567            /* copy match from window to output */
568            do {
569                ROOM();
570                copy = state->wsize - state->offset;
571                if (copy < left) {
572                    from = put + copy;
573                    copy = left - copy;
574                }
575                else {
576                    from = put - state->offset;
577                    copy = left;
578                }
579                if (copy > state->length) copy = state->length;
580                state->length -= copy;
581                left -= copy;
582                do {
583                    *put++ = *from++;
584                } while (--copy);
585            } while (state->length != 0);
586            break;
587
588        case DONE:
589            /* inflate stream terminated properly -- write leftover output */
590            ret = Z_STREAM_END;
591            if (left < state->wsize) {
592                if (out(out_desc, state->window, state->wsize - left))
593                    ret = Z_BUF_ERROR;
594            }
595            goto inf_leave;
596
597        case BAD:
598            ret = Z_DATA_ERROR;
599            goto inf_leave;
600
601        default:                /* can't happen, but makes compilers happy */
602            ret = Z_STREAM_ERROR;
603            goto inf_leave;
604        }
605
606    /* Return unused input */
607  inf_leave:
608    strm->next_in = next;
609    strm->avail_in = have;
610    return ret;
611}
612
613int ZEXPORT inflateBackEnd(strm)
614z_stream FAR *strm;
615{
616    if (strm == Z_NULL || strm->state == Z_NULL || strm->zfree == (free_func)0)
617        return Z_STREAM_ERROR;
618    ZFREE(strm, strm->state);
619    strm->state = Z_NULL;
620    Tracev((stderr, "inflate: end\n"));
621    return Z_OK;
622}
623