1/*
2 * Copyright (C) 2013 Apple Inc. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without
5 * modification, are permitted provided that the following conditions
6 * are met:
7 * 1. Redistributions of source code must retain the above copyright
8 *    notice, this list of conditions and the following disclaimer.
9 * 2. Redistributions in binary form must reproduce the above copyright
10 *    notice, this list of conditions and the following disclaimer in the
11 *    documentation and/or other materials provided with the distribution.
12 *
13 * THIS SOFTWARE IS PROVIDED BY APPLE INC. ``AS IS'' AND ANY
14 * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
16 * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL APPLE INC. OR
17 * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
18 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
19 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
20 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
21 * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
23 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24 */
25
26#include "config.h"
27#include "Compression.h"
28
29#include "CheckedArithmetic.h"
30
31#if USE(ZLIB) && !COMPILER(MSVC)
32
33#include <string.h>
34#include <zlib.h>
35
36namespace WTF {
37
38static void* zAlloc(void*, uint32_t count, uint32_t size)
39{
40    CheckedSize allocSize = count;
41    allocSize *= size;
42    if (allocSize.hasOverflowed())
43        return Z_NULL;
44    void* result = 0;
45    if (tryFastMalloc(allocSize.unsafeGet()).getValue(result))
46        return result;
47    return Z_NULL;
48}
49
50static void zFree(void*, void* data)
51{
52    fastFree(data);
53}
54
55std::unique_ptr<GenericCompressedData> GenericCompressedData::create(const uint8_t* data, size_t dataLength)
56{
57    enum { MinimumSize = sizeof(GenericCompressedData) * 8 };
58
59    if (!data || dataLength < MinimumSize)
60        return nullptr;
61
62    z_stream stream;
63    memset(&stream, 0, sizeof(stream));
64    stream.zalloc = zAlloc;
65    stream.zfree = zFree;
66    stream.data_type = Z_BINARY;
67    stream.opaque = Z_NULL;
68    stream.avail_in = dataLength;
69    stream.next_in = const_cast<uint8_t*>(data);
70
71    size_t currentOffset = OBJECT_OFFSETOF(GenericCompressedData, m_data);
72    size_t currentCapacity = fastMallocGoodSize(MinimumSize);
73    Bytef* compressedData = static_cast<Bytef*>(fastMalloc(currentCapacity));
74    memset(compressedData, 0, sizeof(GenericCompressedData));
75    stream.next_out = compressedData + currentOffset;
76    stream.avail_out = currentCapacity - currentOffset;
77
78    deflateInit(&stream, Z_BEST_COMPRESSION);
79
80    while (true) {
81        int deflateResult = deflate(&stream, Z_FINISH);
82        if (deflateResult == Z_OK || !stream.avail_out) {
83            size_t newCapacity = 0;
84            currentCapacity -= stream.avail_out;
85            if (!stream.avail_in)
86                newCapacity = currentCapacity + 8;
87            else {
88                // Determine average capacity
89                size_t compressedContent = stream.next_in - data;
90                double expectedSize = static_cast<double>(dataLength) * compressedContent / currentCapacity;
91
92                // Expand capacity by at least 8 bytes so we're always growing, and to
93                // compensate for any exaggerated ideas of how effectively we'll compress
94                // data in the future.
95                newCapacity = std::max(static_cast<size_t>(expectedSize + 8), currentCapacity + 8);
96            }
97            newCapacity = fastMallocGoodSize(newCapacity);
98            if (newCapacity >= dataLength)
99                goto fail;
100            compressedData = static_cast<Bytef*>(fastRealloc(compressedData, newCapacity));
101            currentOffset = currentCapacity - stream.avail_out;
102            stream.next_out = compressedData + currentOffset;
103            stream.avail_out = newCapacity - currentCapacity;
104            currentCapacity = newCapacity;
105            continue;
106        }
107
108        if (deflateResult == Z_STREAM_END) {
109            ASSERT(!stream.avail_in);
110            break;
111        }
112
113        ASSERT_NOT_REACHED();
114    fail:
115        deflateEnd(&stream);
116        fastFree(compressedData);
117        return nullptr;
118    }
119    deflateEnd(&stream);
120    static int64_t totalCompressed = 0;
121    static int64_t totalInput = 0;
122
123    totalCompressed += currentCapacity;
124    totalInput += dataLength;
125    return std::unique_ptr<GenericCompressedData>(new (compressedData) GenericCompressedData(dataLength, stream.total_out));
126}
127
128bool GenericCompressedData::decompress(uint8_t* destination, size_t bufferSize, size_t* decompressedByteCount)
129{
130    if (decompressedByteCount)
131        *decompressedByteCount = 0;
132    z_stream stream;
133    memset(&stream, 0, sizeof(stream));
134    stream.zalloc = zAlloc;
135    stream.zfree = zFree;
136    stream.data_type = Z_BINARY;
137    stream.opaque = Z_NULL;
138    stream.next_out = destination;
139    stream.avail_out = bufferSize;
140    stream.next_in = m_data;
141    stream.avail_in = compressedSize();
142    if (inflateInit(&stream) != Z_OK) {
143        ASSERT_NOT_REACHED();
144        return false;
145    }
146
147    int inflateResult = inflate(&stream, Z_FINISH);
148    inflateEnd(&stream);
149
150    ASSERT(stream.total_out <= bufferSize);
151    if (decompressedByteCount)
152        *decompressedByteCount = stream.total_out;
153
154    if (inflateResult != Z_STREAM_END) {
155        ASSERT_NOT_REACHED();
156        return false;
157    }
158
159    return true;
160}
161
162}
163
164#else
165
166namespace WTF {
167std::unique_ptr<GenericCompressedData> GenericCompressedData::create(const uint8_t*, size_t)
168{
169    return nullptr;
170}
171
172bool GenericCompressedData::decompress(uint8_t*, size_t, size_t*)
173{
174    return false;
175}
176}
177
178#endif
179