1/*	$NetBSD: infback.c,v 1.1.1.1 2006/01/14 20:10:29 christos Exp $	*/
2
3/* infback.c -- inflate using a call-back interface
4 * Copyright (C) 1995-2005 Mark Adler
5 * For conditions of distribution and use, see copyright notice in zlib.h
6 */
7
8/*
9   This code is largely copied from inflate.c.  Normally either infback.o or
10   inflate.o would be linked into an application--not both.  The interface
11   with inffast.c is retained so that optimized assembler-coded versions of
12   inflate_fast() can be used with either inflate.c or infback.c.
13 */
14
15#include "zutil.h"
16#include "inftrees.h"
17#include "inflate.h"
18#include "inffast.h"
19
/* function prototypes */
/* fixedtables() (defined below) points state's decoding tables at the fixed
   Huffman codes used by fixed-code blocks. */
local void fixedtables OF((struct inflate_state FAR *state));
22
23/*
24   strm provides memory allocation functions in zalloc and zfree, or
25   Z_NULL to use the library memory allocation functions.
26
27   windowBits is in the range 8..15, and window is a user-supplied
28   window and output buffer that is 2**windowBits bytes.
29 */
30int ZEXPORT inflateBackInit_(strm, windowBits, window, version, stream_size)
31z_streamp strm;
32int windowBits;
33unsigned char FAR *window;
34const char *version;
35int stream_size;
36{
37    struct inflate_state FAR *state;
38
39    if (version == Z_NULL || version[0] != ZLIB_VERSION[0] ||
40        stream_size != (int)(sizeof(z_stream)))
41        return Z_VERSION_ERROR;
42    if (strm == Z_NULL || window == Z_NULL ||
43        windowBits < 8 || windowBits > 15)
44        return Z_STREAM_ERROR;
45    strm->msg = Z_NULL;                 /* in case we return an error */
46    if (strm->zalloc == (alloc_func)0) {
47        strm->zalloc = zcalloc;
48        strm->opaque = (voidpf)0;
49    }
50    if (strm->zfree == (free_func)0) strm->zfree = zcfree;
51    state = (struct inflate_state FAR *)ZALLOC(strm, 1,
52                                               sizeof(struct inflate_state));
53    if (state == Z_NULL) return Z_MEM_ERROR;
54    Tracev((stderr, "inflate: allocated\n"));
55    strm->state = (struct internal_state FAR *)state;
56    state->dmax = 32768U;
57    state->wbits = windowBits;
58    state->wsize = 1U << windowBits;
59    state->window = window;
60    state->write = 0;
61    state->whave = 0;
62    return Z_OK;
63}
64
/*
   Point state's length and distance decoding tables at the fixed Huffman
   codes, and set the table index sizes to match (lenbits = 9,
   distbits = 5).  inflateBack() calls this for fixed-code blocks.

   Normally the tables used are lenfix and distfix from inffixed.h.
   If BUILDFIXED is defined, then instead this routine builds the tables with
   inflate_table() into function-local static storage the first time it is
   called, and returns those same tables on that and every later call.  This
   reduces the size of the code by about 2K bytes, in exchange for a little
   execution time.  However, BUILDFIXED should not be used for threaded
   applications, since the rewriting of the tables and virgin may not be
   thread-safe.  On that first BUILDFIXED call, state->lens and state->work
   are used as scratch space.
 */
local void fixedtables(state)
struct inflate_state FAR *state;
{
#ifdef BUILDFIXED
    static int virgin = 1;          /* nonzero until the tables are built */
    static code *lenfix, *distfix;  /* start of each table within fixed[] */
    static code fixed[544];         /* storage shared by both tables */

    /* build fixed huffman tables if first call (may not be thread safe) */
    if (virgin) {
        unsigned sym, bits;
        static code *next;

        /* literal/length table: code lengths 8, 9, 7, 8 for symbols
           0..143, 144..255, 256..279, 280..287 */
        sym = 0;
        while (sym < 144) state->lens[sym++] = 8;
        while (sym < 256) state->lens[sym++] = 9;
        while (sym < 280) state->lens[sym++] = 7;
        while (sym < 288) state->lens[sym++] = 8;
        next = fixed;
        lenfix = next;
        bits = 9;
        /* the return value of inflate_table() is not checked here */
        inflate_table(LENS, state->lens, 288, &(next), &(bits), state->work);

        /* distance table: all 32 codes have length 5; it starts wherever
           inflate_table() left next after the literal/length table */
        sym = 0;
        while (sym < 32) state->lens[sym++] = 5;
        distfix = next;
        bits = 5;
        inflate_table(DISTS, state->lens, 32, &(next), &(bits), state->work);

        /* do this just once */
        virgin = 0;
    }
#else /* !BUILDFIXED */
    /* inffixed.h supplies the lenfix and distfix tables used below */
#   include "inffixed.h"
#endif /* BUILDFIXED */
    state->lencode = lenfix;
    state->lenbits = 9;
    state->distcode = distfix;
    state->distbits = 5;
}
117
/* Macros for inflateBack(): */

/* These macros expand inside inflateBack() and read/write its local
   variables (strm, state, put, left, next, have, hold, bits) by name. */

/* Load returned state from inflate_fast(): copy the output/input pointers
   and counts from strm, and the bit accumulator from state, back into
   inflateBack()'s local "registers" */
#define LOAD() \
    do { \
        put = strm->next_out; \
        left = strm->avail_out; \
        next = strm->next_in; \
        have = strm->avail_in; \
        hold = state->hold; \
        bits = state->bits; \
    } while (0)

/* Set state from registers for inflate_fast(): the inverse of LOAD().
   inflate_fast() is passed only strm, so the locals must be stored into
   strm and state before the call */
#define RESTORE() \
    do { \
        strm->next_out = put; \
        strm->avail_out = left; \
        strm->next_in = next; \
        strm->avail_in = have; \
        state->hold = hold; \
        state->bits = bits; \
    } while (0)

/* Clear the input bit accumulator */
#define INITBITS() \
    do { \
        hold = 0; \
        bits = 0; \
    } while (0)
148
/* Assure that some input is available.  If input is requested, but denied,
   then return a Z_BUF_ERROR from inflateBack().  On that failure next is set
   to Z_NULL, which is what lets the caller tell an in() failure from an
   out() failure after inflateBack() returns. */
#define PULL() \
    do { \
        if (have == 0) { \
            have = in(in_desc, &next); \
            if (have == 0) { \
                next = Z_NULL; \
                ret = Z_BUF_ERROR; \
                goto inf_leave; \
            } \
        } \
    } while (0)

/* Get a byte of input into the bit accumulator, or return from inflateBack()
   with an error if there is no input available.  The new byte is placed
   above the bits already held, so bits are consumed least-significant
   first. */
#define PULLBYTE() \
    do { \
        PULL(); \
        have--; \
        hold += (unsigned long)(*next++) << bits; \
        bits += 8; \
    } while (0)

/* Assure that there are at least n bits in the bit accumulator.  If there is
   not enough available input to do that, then return from inflateBack() with
   an error. */
#define NEEDBITS(n) \
    do { \
        while (bits < (unsigned)(n)) \
            PULLBYTE(); \
    } while (0)

/* Return the low n bits of the bit accumulator (n < 16) */
#define BITS(n) \
    ((unsigned)hold & ((1U << (n)) - 1))

/* Remove n bits from the bit accumulator */
#define DROPBITS(n) \
    do { \
        hold >>= (n); \
        bits -= (unsigned)(n); \
    } while (0)

/* Remove zero to seven bits as needed to go to a byte boundary */
#define BYTEBITS() \
    do { \
        hold >>= bits & 7; \
        bits -= bits & 7; \
    } while (0)

/* Assure that some output space is available, by writing out the window
   if it's full.  If the write fails, return from inflateBack() with a
   Z_BUF_ERROR.  When the window is full, output restarts at the start of
   the window, and whave is set to the full window size so that later
   distance checks treat the whole window as valid history.  put and left
   are reset before out() is called, so out() is handed the whole window. */
#define ROOM() \
    do { \
        if (left == 0) { \
            put = state->window; \
            left = state->wsize; \
            state->whave = left; \
            if (out(out_desc, put, left)) { \
                ret = Z_BUF_ERROR; \
                goto inf_leave; \
            } \
        } \
    } while (0)
215
216/*
217   strm provides the memory allocation functions and window buffer on input,
218   and provides information on the unused input on return.  For Z_DATA_ERROR
219   returns, strm will also provide an error message.
220
221   in() and out() are the call-back input and output functions.  When
222   inflateBack() needs more input, it calls in().  When inflateBack() has
223   filled the window with output, or when it completes with data in the
224   window, it calls out() to write out the data.  The application must not
225   change the provided input until in() is called again or inflateBack()
226   returns.  The application must not change the window/output buffer until
227   inflateBack() returns.
228
229   in() and out() are called with a descriptor parameter provided in the
230   inflateBack() call.  This parameter can be a structure that provides the
231   information required to do the read or write, as well as accumulated
232   information on the input and output such as totals and check values.
233
234   in() should return zero on failure.  out() should return non-zero on
235   failure.  If either in() or out() fails, than inflateBack() returns a
236   Z_BUF_ERROR.  strm->next_in can be checked for Z_NULL to see whether it
237   was in() or out() that caused in the error.  Otherwise,  inflateBack()
238   returns Z_STREAM_END on success, Z_DATA_ERROR for an deflate format
239   error, or Z_MEM_ERROR if it could not allocate memory for the state.
240   inflateBack() can also return Z_STREAM_ERROR if the input parameters
241   are not correct, i.e. strm is Z_NULL or the state was not initialized.
242 */
243int ZEXPORT inflateBack(strm, in, in_desc, out, out_desc)
244z_streamp strm;
245in_func in;
246void FAR *in_desc;
247out_func out;
248void FAR *out_desc;
249{
250    struct inflate_state FAR *state;
251    unsigned char FAR *next;    /* next input */
252    unsigned char FAR *put;     /* next output */
253    unsigned have, left;        /* available input and output */
254    unsigned long hold;         /* bit buffer */
255    unsigned bits;              /* bits in bit buffer */
256    unsigned copy;              /* number of stored or match bytes to copy */
257    unsigned char FAR *from;    /* where to copy match bytes from */
258    code this;                  /* current decoding table entry */
259    code last;                  /* parent table entry */
260    unsigned len;               /* length to copy for repeats, bits to drop */
261    int ret;                    /* return code */
262    static const unsigned short order[19] = /* permutation of code lengths */
263        {16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15};
264
265    /* Check that the strm exists and that the state was initialized */
266    if (strm == Z_NULL || strm->state == Z_NULL)
267        return Z_STREAM_ERROR;
268    state = (struct inflate_state FAR *)strm->state;
269
270    /* Reset the state */
271    strm->msg = Z_NULL;
272    state->mode = TYPE;
273    state->last = 0;
274    state->whave = 0;
275    next = strm->next_in;
276    have = next != Z_NULL ? strm->avail_in : 0;
277    hold = 0;
278    bits = 0;
279    put = state->window;
280    left = state->wsize;
281
282    /* Inflate until end of block marked as last */
283    for (;;)
284        switch (state->mode) {
285        case TYPE:
286            /* determine and dispatch block type */
287            if (state->last) {
288                BYTEBITS();
289                state->mode = DONE;
290                break;
291            }
292            NEEDBITS(3);
293            state->last = BITS(1);
294            DROPBITS(1);
295            switch (BITS(2)) {
296            case 0:                             /* stored block */
297                Tracev((stderr, "inflate:     stored block%s\n",
298                        state->last ? " (last)" : ""));
299                state->mode = STORED;
300                break;
301            case 1:                             /* fixed block */
302                fixedtables(state);
303                Tracev((stderr, "inflate:     fixed codes block%s\n",
304                        state->last ? " (last)" : ""));
305                state->mode = LEN;              /* decode codes */
306                break;
307            case 2:                             /* dynamic block */
308                Tracev((stderr, "inflate:     dynamic codes block%s\n",
309                        state->last ? " (last)" : ""));
310                state->mode = TABLE;
311                break;
312            case 3:
313                strm->msg = __UNCONST("invalid block type");
314                state->mode = BAD;
315            }
316            DROPBITS(2);
317            break;
318
319        case STORED:
320            /* get and verify stored block length */
321            BYTEBITS();                         /* go to byte boundary */
322            NEEDBITS(32);
323            if ((hold & 0xffff) != ((hold >> 16) ^ 0xffff)) {
324                strm->msg = __UNCONST("invalid stored block lengths");
325                state->mode = BAD;
326                break;
327            }
328            state->length = (unsigned)hold & 0xffff;
329            Tracev((stderr, "inflate:       stored length %u\n",
330                    state->length));
331            INITBITS();
332
333            /* copy stored block from input to output */
334            while (state->length != 0) {
335                copy = state->length;
336                PULL();
337                ROOM();
338                if (copy > have) copy = have;
339                if (copy > left) copy = left;
340                zmemcpy(put, next, copy);
341                have -= copy;
342                next += copy;
343                left -= copy;
344                put += copy;
345                state->length -= copy;
346            }
347            Tracev((stderr, "inflate:       stored end\n"));
348            state->mode = TYPE;
349            break;
350
351        case TABLE:
352            /* get dynamic table entries descriptor */
353            NEEDBITS(14);
354            state->nlen = BITS(5) + 257;
355            DROPBITS(5);
356            state->ndist = BITS(5) + 1;
357            DROPBITS(5);
358            state->ncode = BITS(4) + 4;
359            DROPBITS(4);
360#ifndef PKZIP_BUG_WORKAROUND
361            if (state->nlen > 286 || state->ndist > 30) {
362                strm->msg = __UNCONST("too many length or distance symbols");
363                state->mode = BAD;
364                break;
365            }
366#endif
367            Tracev((stderr, "inflate:       table sizes ok\n"));
368
369            /* get code length code lengths (not a typo) */
370            state->have = 0;
371            while (state->have < state->ncode) {
372                NEEDBITS(3);
373                state->lens[order[state->have++]] = (unsigned short)BITS(3);
374                DROPBITS(3);
375            }
376            while (state->have < 19)
377                state->lens[order[state->have++]] = 0;
378            state->next = state->codes;
379            state->lencode = (code const FAR *)(state->next);
380            state->lenbits = 7;
381            ret = inflate_table(CODES, state->lens, 19, &(state->next),
382                                &(state->lenbits), state->work);
383            if (ret) {
384                strm->msg = __UNCONST("invalid code lengths set");
385                state->mode = BAD;
386                break;
387            }
388            Tracev((stderr, "inflate:       code lengths ok\n"));
389
390            /* get length and distance code code lengths */
391            state->have = 0;
392            while (state->have < state->nlen + state->ndist) {
393                for (;;) {
394                    this = state->lencode[BITS(state->lenbits)];
395                    if ((unsigned)(this.bits) <= bits) break;
396                    PULLBYTE();
397                }
398                if (this.val < 16) {
399                    NEEDBITS(this.bits);
400                    DROPBITS(this.bits);
401                    state->lens[state->have++] = this.val;
402                }
403                else {
404                    if (this.val == 16) {
405                        NEEDBITS(this.bits + 2);
406                        DROPBITS(this.bits);
407                        if (state->have == 0) {
408                            strm->msg = __UNCONST("invalid bit length repeat");
409                            state->mode = BAD;
410                            break;
411                        }
412                        len = (unsigned)(state->lens[state->have - 1]);
413                        copy = 3 + BITS(2);
414                        DROPBITS(2);
415                    }
416                    else if (this.val == 17) {
417                        NEEDBITS(this.bits + 3);
418                        DROPBITS(this.bits);
419                        len = 0;
420                        copy = 3 + BITS(3);
421                        DROPBITS(3);
422                    }
423                    else {
424                        NEEDBITS(this.bits + 7);
425                        DROPBITS(this.bits);
426                        len = 0;
427                        copy = 11 + BITS(7);
428                        DROPBITS(7);
429                    }
430                    if (state->have + copy > state->nlen + state->ndist) {
431                        strm->msg = __UNCONST("invalid bit length repeat");
432                        state->mode = BAD;
433                        break;
434                    }
435                    while (copy--)
436                        state->lens[state->have++] = (unsigned short)len;
437                }
438            }
439
440            /* handle error breaks in while */
441            if (state->mode == BAD) break;
442
443            /* build code tables */
444            state->next = state->codes;
445            state->lencode = (code const FAR *)(state->next);
446            state->lenbits = 9;
447            ret = inflate_table(LENS, state->lens, state->nlen, &(state->next),
448                                &(state->lenbits), state->work);
449            if (ret) {
450                strm->msg = __UNCONST("invalid literal/lengths set");
451                state->mode = BAD;
452                break;
453            }
454            state->distcode = (code const FAR *)(state->next);
455            state->distbits = 6;
456            ret = inflate_table(DISTS, state->lens + state->nlen, state->ndist,
457                            &(state->next), &(state->distbits), state->work);
458            if (ret) {
459                strm->msg = __UNCONST("invalid distances set");
460                state->mode = BAD;
461                break;
462            }
463            Tracev((stderr, "inflate:       codes ok\n"));
464            state->mode = LEN;
465
466        case LEN:
467            /* use inflate_fast() if we have enough input and output */
468            if (have >= 6 && left >= 258) {
469                RESTORE();
470                if (state->whave < state->wsize)
471                    state->whave = state->wsize - left;
472                inflate_fast(strm, state->wsize);
473                LOAD();
474                break;
475            }
476
477            /* get a literal, length, or end-of-block code */
478            for (;;) {
479                this = state->lencode[BITS(state->lenbits)];
480                if ((unsigned)(this.bits) <= bits) break;
481                PULLBYTE();
482            }
483            if (this.op && (this.op & 0xf0) == 0) {
484                last = this;
485                for (;;) {
486                    this = state->lencode[last.val +
487                            (BITS(last.bits + last.op) >> last.bits)];
488                    if ((unsigned)(last.bits + this.bits) <= bits) break;
489                    PULLBYTE();
490                }
491                DROPBITS(last.bits);
492            }
493            DROPBITS(this.bits);
494            state->length = (unsigned)this.val;
495
496            /* process literal */
497            if (this.op == 0) {
498                Tracevv((stderr, this.val >= 0x20 && this.val < 0x7f ?
499                        "inflate:         literal '%c'\n" :
500                        "inflate:         literal 0x%02x\n", this.val));
501                ROOM();
502                *put++ = (unsigned char)(state->length);
503                left--;
504                state->mode = LEN;
505                break;
506            }
507
508            /* process end of block */
509            if (this.op & 32) {
510                Tracevv((stderr, "inflate:         end of block\n"));
511                state->mode = TYPE;
512                break;
513            }
514
515            /* invalid code */
516            if (this.op & 64) {
517                strm->msg = __UNCONST("invalid literal/length code");
518                state->mode = BAD;
519                break;
520            }
521
522            /* length code -- get extra bits, if any */
523            state->extra = (unsigned)(this.op) & 15;
524            if (state->extra != 0) {
525                NEEDBITS(state->extra);
526                state->length += BITS(state->extra);
527                DROPBITS(state->extra);
528            }
529            Tracevv((stderr, "inflate:         length %u\n", state->length));
530
531            /* get distance code */
532            for (;;) {
533                this = state->distcode[BITS(state->distbits)];
534                if ((unsigned)(this.bits) <= bits) break;
535                PULLBYTE();
536            }
537            if ((this.op & 0xf0) == 0) {
538                last = this;
539                for (;;) {
540                    this = state->distcode[last.val +
541                            (BITS(last.bits + last.op) >> last.bits)];
542                    if ((unsigned)(last.bits + this.bits) <= bits) break;
543                    PULLBYTE();
544                }
545                DROPBITS(last.bits);
546            }
547            DROPBITS(this.bits);
548            if (this.op & 64) {
549                strm->msg = __UNCONST("invalid distance code");
550                state->mode = BAD;
551                break;
552            }
553            state->offset = (unsigned)this.val;
554
555            /* get distance extra bits, if any */
556            state->extra = (unsigned)(this.op) & 15;
557            if (state->extra != 0) {
558                NEEDBITS(state->extra);
559                state->offset += BITS(state->extra);
560                DROPBITS(state->extra);
561            }
562            if (state->offset > state->wsize - (state->whave < state->wsize ?
563                                                left : 0)) {
564                strm->msg = __UNCONST("invalid distance too far back");
565                state->mode = BAD;
566                break;
567            }
568            Tracevv((stderr, "inflate:         distance %u\n", state->offset));
569
570            /* copy match from window to output */
571            do {
572                ROOM();
573                copy = state->wsize - state->offset;
574                if (copy < left) {
575                    from = put + copy;
576                    copy = left - copy;
577                }
578                else {
579                    from = put - state->offset;
580                    copy = left;
581                }
582                if (copy > state->length) copy = state->length;
583                state->length -= copy;
584                left -= copy;
585                do {
586                    *put++ = *from++;
587                } while (--copy);
588            } while (state->length != 0);
589            break;
590
591        case DONE:
592            /* inflate stream terminated properly -- write leftover output */
593            ret = Z_STREAM_END;
594            if (left < state->wsize) {
595                if (out(out_desc, state->window, state->wsize - left))
596                    ret = Z_BUF_ERROR;
597            }
598            goto inf_leave;
599
600        case BAD:
601            ret = Z_DATA_ERROR;
602            goto inf_leave;
603
604        default:                /* can't happen, but makes compilers happy */
605            ret = Z_STREAM_ERROR;
606            goto inf_leave;
607        }
608
609    /* Return unused input */
610  inf_leave:
611    strm->next_in = next;
612    strm->avail_in = have;
613    return ret;
614}
615
616int ZEXPORT inflateBackEnd(strm)
617z_streamp strm;
618{
619    if (strm == Z_NULL || strm->state == Z_NULL || strm->zfree == (free_func)0)
620        return Z_STREAM_ERROR;
621    ZFREE(strm, strm->state);
622    strm->state = Z_NULL;
623    Tracev((stderr, "inflate: end\n"));
624    return Z_OK;
625}
626