1// zinflate.cpp - written and placed in the public domain by Wei Dai
2
3// This is a complete reimplementation of the DEFLATE decompression algorithm.
4// It should not be affected by any security vulnerabilities in the zlib
5// compression library. In particular it is not affected by the double free bug
6// (http://www.kb.cert.org/vuls/id/368819).
7
8#include "pch.h"
9#include "zinflate.h"
10
11NAMESPACE_BEGIN(CryptoPP)
12
13struct CodeLessThan
14{
15	inline bool operator()(CryptoPP::HuffmanDecoder::code_t lhs, const CryptoPP::HuffmanDecoder::CodeInfo &rhs)
16		{return lhs < rhs.code;}
17	// needed for MSVC .NET 2005
18	inline bool operator()(const CryptoPP::HuffmanDecoder::CodeInfo &lhs, const CryptoPP::HuffmanDecoder::CodeInfo &rhs)
19		{return lhs.code < rhs.code;}
20};
21
22inline bool LowFirstBitReader::FillBuffer(unsigned int length)
23{
24	while (m_bitsBuffered < length)
25	{
26		byte b;
27		if (!m_store.Get(b))
28			return false;
29		m_buffer |= (unsigned long)b << m_bitsBuffered;
30		m_bitsBuffered += 8;
31	}
32	assert(m_bitsBuffered <= sizeof(unsigned long)*8);
33	return true;
34}
35
36inline unsigned long LowFirstBitReader::PeekBits(unsigned int length)
37{
38	bool result = FillBuffer(length);
39	assert(result);
40	return m_buffer & (((unsigned long)1 << length) - 1);
41}
42
43inline void LowFirstBitReader::SkipBits(unsigned int length)
44{
45	assert(m_bitsBuffered >= length);
46	m_buffer >>= length;
47	m_bitsBuffered -= length;
48}
49
50inline unsigned long LowFirstBitReader::GetBits(unsigned int length)
51{
52	unsigned long result = PeekBits(length);
53	SkipBits(length);
54	return result;
55}
56
57inline HuffmanDecoder::code_t HuffmanDecoder::NormalizeCode(HuffmanDecoder::code_t code, unsigned int codeBits)
58{
59	return code << (MAX_CODE_BITS - codeBits);
60}
61
62void HuffmanDecoder::Initialize(const unsigned int *codeBits, unsigned int nCodes)
63{
64	// the Huffman codes are represented in 3 ways in this code:
65	//
66	// 1. most significant code bit (i.e. top of code tree) in the least significant bit position
67	// 2. most significant code bit (i.e. top of code tree) in the most significant bit position
68	// 3. most significant code bit (i.e. top of code tree) in n-th least significant bit position,
69	//    where n is the maximum code length for this code tree
70	//
71	// (1) is the way the codes come in from the deflate stream
72	// (2) is used to sort codes so they can be binary searched
73	// (3) is used in this function to compute codes from code lengths
74	//
75	// a code in representation (2) is called "normalized" here
76	// The BitReverse() function is used to convert between (1) and (2)
77	// The NormalizeCode() function is used to convert from (3) to (2)
78
79	if (nCodes == 0)
80		throw Err("null code");
81
82	m_maxCodeBits = *std::max_element(codeBits, codeBits+nCodes);
83
84	if (m_maxCodeBits > MAX_CODE_BITS)
85		throw Err("code length exceeds maximum");
86
87	if (m_maxCodeBits == 0)
88		throw Err("null code");
89
90	// count number of codes of each length
91	SecBlockWithHint<unsigned int, 15+1> blCount(m_maxCodeBits+1);
92	std::fill(blCount.begin(), blCount.end(), 0);
93	unsigned int i;
94	for (i=0; i<nCodes; i++)
95		blCount[codeBits[i]]++;
96
97	// compute the starting code of each length
98	code_t code = 0;
99	SecBlockWithHint<code_t, 15+1> nextCode(m_maxCodeBits+1);
100	nextCode[1] = 0;
101	for (i=2; i<=m_maxCodeBits; i++)
102	{
103		// compute this while checking for overflow: code = (code + blCount[i-1]) << 1
104		if (code > code + blCount[i-1])
105			throw Err("codes oversubscribed");
106		code += blCount[i-1];
107		if (code > (code << 1))
108			throw Err("codes oversubscribed");
109		code <<= 1;
110		nextCode[i] = code;
111	}
112
113	if (code > (1 << m_maxCodeBits) - blCount[m_maxCodeBits])
114		throw Err("codes oversubscribed");
115	else if (m_maxCodeBits != 1 && code < (1 << m_maxCodeBits) - blCount[m_maxCodeBits])
116		throw Err("codes incomplete");
117
118	// compute a vector of <code, length, value> triples sorted by code
119	m_codeToValue.resize(nCodes - blCount[0]);
120	unsigned int j=0;
121	for (i=0; i<nCodes; i++)
122	{
123		unsigned int len = codeBits[i];
124		if (len != 0)
125		{
126			code = NormalizeCode(nextCode[len]++, len);
127			m_codeToValue[j].code = code;
128			m_codeToValue[j].len = len;
129			m_codeToValue[j].value = i;
130			j++;
131		}
132	}
133	std::sort(m_codeToValue.begin(), m_codeToValue.end());
134
135	// initialize the decoding cache
136	m_cacheBits = STDMIN(9U, m_maxCodeBits);
137	m_cacheMask = (1 << m_cacheBits) - 1;
138	m_normalizedCacheMask = NormalizeCode(m_cacheMask, m_cacheBits);
139	assert(m_normalizedCacheMask == BitReverse(m_cacheMask));
140
141	if (m_cache.size() != size_t(1) << m_cacheBits)
142		m_cache.resize(1 << m_cacheBits);
143
144	for (i=0; i<m_cache.size(); i++)
145		m_cache[i].type = 0;
146}
147
148void HuffmanDecoder::FillCacheEntry(LookupEntry &entry, code_t normalizedCode) const
149{
150	normalizedCode &= m_normalizedCacheMask;
151	const CodeInfo &codeInfo = *(std::upper_bound(m_codeToValue.begin(), m_codeToValue.end(), normalizedCode, CodeLessThan())-1);
152	if (codeInfo.len <= m_cacheBits)
153	{
154		entry.type = 1;
155		entry.value = codeInfo.value;
156		entry.len = codeInfo.len;
157	}
158	else
159	{
160		entry.begin = &codeInfo;
161		const CodeInfo *last = & *(std::upper_bound(m_codeToValue.begin(), m_codeToValue.end(), normalizedCode + ~m_normalizedCacheMask, CodeLessThan())-1);
162		if (codeInfo.len == last->len)
163		{
164			entry.type = 2;
165			entry.len = codeInfo.len;
166		}
167		else
168		{
169			entry.type = 3;
170			entry.end = last+1;
171		}
172	}
173}
174
175inline unsigned int HuffmanDecoder::Decode(code_t code, /* out */ value_t &value) const
176{
177	assert(m_codeToValue.size() > 0);
178	LookupEntry &entry = m_cache[code & m_cacheMask];
179
180	code_t normalizedCode;
181	if (entry.type != 1)
182		normalizedCode = BitReverse(code);
183
184	if (entry.type == 0)
185		FillCacheEntry(entry, normalizedCode);
186
187	if (entry.type == 1)
188	{
189		value = entry.value;
190		return entry.len;
191	}
192	else
193	{
194		const CodeInfo &codeInfo = (entry.type == 2)
195			? entry.begin[(normalizedCode << m_cacheBits) >> (MAX_CODE_BITS - (entry.len - m_cacheBits))]
196			: *(std::upper_bound(entry.begin, entry.end, normalizedCode, CodeLessThan())-1);
197		value = codeInfo.value;
198		return codeInfo.len;
199	}
200}
201
202bool HuffmanDecoder::Decode(LowFirstBitReader &reader, value_t &value) const
203{
204	reader.FillBuffer(m_maxCodeBits);
205	unsigned int codeBits = Decode(reader.PeekBuffer(), value);
206	if (codeBits > reader.BitsBuffered())
207		return false;
208	reader.SkipBits(codeBits);
209	return true;
210}
211
212// *************************************************************
213
214Inflator::Inflator(BufferedTransformation *attachment, bool repeat, int propagation)
215	: AutoSignaling<Filter>(propagation)
216	, m_state(PRE_STREAM), m_repeat(repeat), m_reader(m_inQueue)
217{
218	Detach(attachment);
219}
220
221void Inflator::IsolatedInitialize(const NameValuePairs &parameters)
222{
223	m_state = PRE_STREAM;
224	parameters.GetValue("Repeat", m_repeat);
225	m_inQueue.Clear();
226	m_reader.SkipBits(m_reader.BitsBuffered());
227}
228
229void Inflator::OutputByte(byte b)
230{
231	m_window[m_current++] = b;
232	if (m_current == m_window.size())
233	{
234		ProcessDecompressedData(m_window + m_lastFlush, m_window.size() - m_lastFlush);
235		m_lastFlush = 0;
236		m_current = 0;
237		m_wrappedAround = true;
238	}
239}
240
241void Inflator::OutputString(const byte *string, size_t length)
242{
243	while (length)
244	{
245		size_t len = UnsignedMin(length, m_window.size() - m_current);
246		memcpy(m_window + m_current, string, len);
247		m_current += len;
248		if (m_current == m_window.size())
249		{
250			ProcessDecompressedData(m_window + m_lastFlush, m_window.size() - m_lastFlush);
251			m_lastFlush = 0;
252			m_current = 0;
253			m_wrappedAround = true;
254		}
255		string += len;
256		length -= len;
257	}
258}
259
260void Inflator::OutputPast(unsigned int length, unsigned int distance)
261{
262	size_t start;
263	if (distance <= m_current)
264		start = m_current - distance;
265	else if (m_wrappedAround && distance <= m_window.size())
266		start = m_current + m_window.size() - distance;
267	else
268		throw BadBlockErr();
269
270	if (start + length > m_window.size())
271	{
272		for (; start < m_window.size(); start++, length--)
273			OutputByte(m_window[start]);
274		start = 0;
275	}
276
277	if (start + length > m_current || m_current + length >= m_window.size())
278	{
279		while (length--)
280			OutputByte(m_window[start++]);
281	}
282	else
283	{
284		memcpy(m_window + m_current, m_window + start, length);
285		m_current += length;
286	}
287}
288
289size_t Inflator::Put2(const byte *inString, size_t length, int messageEnd, bool blocking)
290{
291	if (!blocking)
292		throw BlockingInputOnly("Inflator");
293
294	LazyPutter lp(m_inQueue, inString, length);
295	ProcessInput(messageEnd != 0);
296
297	if (messageEnd)
298		if (!(m_state == PRE_STREAM || m_state == AFTER_END))
299			throw UnexpectedEndErr();
300
301	Output(0, NULL, 0, messageEnd, blocking);
302	return 0;
303}
304
305bool Inflator::IsolatedFlush(bool hardFlush, bool blocking)
306{
307	if (!blocking)
308		throw BlockingInputOnly("Inflator");
309
310	if (hardFlush)
311		ProcessInput(true);
312	FlushOutput();
313
314	return false;
315}
316
317void Inflator::ProcessInput(bool flush)
318{
319	while (true)
320	{
321		switch (m_state)
322		{
323		case PRE_STREAM:
324			if (!flush && m_inQueue.CurrentSize() < MaxPrestreamHeaderSize())
325				return;
326			ProcessPrestreamHeader();
327			m_state = WAIT_HEADER;
328			m_wrappedAround = false;
329			m_current = 0;
330			m_lastFlush = 0;
331			m_window.New(1 << GetLog2WindowSize());
332			break;
333		case WAIT_HEADER:
334			{
335			// maximum number of bytes before actual compressed data starts
336			const size_t MAX_HEADER_SIZE = BitsToBytes(3+5+5+4+19*7+286*15+19*15);
337			if (m_inQueue.CurrentSize() < (flush ? 1 : MAX_HEADER_SIZE))
338				return;
339			DecodeHeader();
340			break;
341			}
342		case DECODING_BODY:
343			if (!DecodeBody())
344				return;
345			break;
346		case POST_STREAM:
347			if (!flush && m_inQueue.CurrentSize() < MaxPoststreamTailSize())
348				return;
349			ProcessPoststreamTail();
350			m_state = m_repeat ? PRE_STREAM : AFTER_END;
351			Output(0, NULL, 0, GetAutoSignalPropagation(), true);	// TODO: non-blocking
352			if (m_inQueue.IsEmpty())
353				return;
354			break;
355		case AFTER_END:
356			m_inQueue.TransferTo(*AttachedTransformation());
357			return;
358		}
359	}
360}
361
362void Inflator::DecodeHeader()
363{
364	if (!m_reader.FillBuffer(3))
365		throw UnexpectedEndErr();
366	m_eof = m_reader.GetBits(1) != 0;
367	m_blockType = (byte)m_reader.GetBits(2);
368	switch (m_blockType)
369	{
370	case 0:	// stored
371		{
372		m_reader.SkipBits(m_reader.BitsBuffered() % 8);
373		if (!m_reader.FillBuffer(32))
374			throw UnexpectedEndErr();
375		m_storedLen = (word16)m_reader.GetBits(16);
376		word16 nlen = (word16)m_reader.GetBits(16);
377		if (nlen != (word16)~m_storedLen)
378			throw BadBlockErr();
379		break;
380		}
381	case 1:	// fixed codes
382		m_nextDecode = LITERAL;
383		break;
384	case 2:	// dynamic codes
385		{
386		if (!m_reader.FillBuffer(5+5+4))
387			throw UnexpectedEndErr();
388		unsigned int hlit = m_reader.GetBits(5);
389		unsigned int hdist = m_reader.GetBits(5);
390		unsigned int hclen = m_reader.GetBits(4);
391
392		FixedSizeSecBlock<unsigned int, 286+32> codeLengths;
393		unsigned int i;
394		static const unsigned int border[] = {    // Order of the bit length code lengths
395			16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15};
396		std::fill(codeLengths.begin(), codeLengths+19, 0);
397		for (i=0; i<hclen+4; i++)
398			codeLengths[border[i]] = m_reader.GetBits(3);
399
400		try
401		{
402			HuffmanDecoder codeLengthDecoder(codeLengths, 19);
403			for (i = 0; i < hlit+257+hdist+1; )
404			{
405				unsigned int k, count, repeater;
406				bool result = codeLengthDecoder.Decode(m_reader, k);
407				if (!result)
408					throw UnexpectedEndErr();
409				if (k <= 15)
410				{
411					count = 1;
412					repeater = k;
413				}
414				else switch (k)
415				{
416				case 16:
417					if (!m_reader.FillBuffer(2))
418						throw UnexpectedEndErr();
419					count = 3 + m_reader.GetBits(2);
420					if (i == 0)
421						throw BadBlockErr();
422					repeater = codeLengths[i-1];
423					break;
424				case 17:
425					if (!m_reader.FillBuffer(3))
426						throw UnexpectedEndErr();
427					count = 3 + m_reader.GetBits(3);
428					repeater = 0;
429					break;
430				case 18:
431					if (!m_reader.FillBuffer(7))
432						throw UnexpectedEndErr();
433					count = 11 + m_reader.GetBits(7);
434					repeater = 0;
435					break;
436				}
437				if (i + count > hlit+257+hdist+1)
438					throw BadBlockErr();
439				std::fill(codeLengths + i, codeLengths + i + count, repeater);
440				i += count;
441			}
442			m_dynamicLiteralDecoder.Initialize(codeLengths, hlit+257);
443			if (hdist == 0 && codeLengths[hlit+257] == 0)
444			{
445				if (hlit != 0)	// a single zero distance code length means all literals
446					throw BadBlockErr();
447			}
448			else
449				m_dynamicDistanceDecoder.Initialize(codeLengths+hlit+257, hdist+1);
450			m_nextDecode = LITERAL;
451		}
452		catch (HuffmanDecoder::Err &)
453		{
454			throw BadBlockErr();
455		}
456		break;
457		}
458	default:
459		throw BadBlockErr();	// reserved block type
460	}
461	m_state = DECODING_BODY;
462}
463
464bool Inflator::DecodeBody()
465{
466	bool blockEnd = false;
467	switch (m_blockType)
468	{
469	case 0:	// stored
470		assert(m_reader.BitsBuffered() == 0);
471		while (!m_inQueue.IsEmpty() && !blockEnd)
472		{
473			size_t size;
474			const byte *block = m_inQueue.Spy(size);
475			size = UnsignedMin(m_storedLen, size);
476			OutputString(block, size);
477			m_inQueue.Skip(size);
478			m_storedLen -= (word16)size;
479			if (m_storedLen == 0)
480				blockEnd = true;
481		}
482		break;
483	case 1:	// fixed codes
484	case 2:	// dynamic codes
485		static const unsigned int lengthStarts[] = {
486			3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31,
487			35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227, 258};
488		static const unsigned int lengthExtraBits[] = {
489			0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2,
490			3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0};
491		static const unsigned int distanceStarts[] = {
492			1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193,
493			257, 385, 513, 769, 1025, 1537, 2049, 3073, 4097, 6145,
494			8193, 12289, 16385, 24577};
495		static const unsigned int distanceExtraBits[] = {
496			0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6,
497			7, 7, 8, 8, 9, 9, 10, 10, 11, 11,
498			12, 12, 13, 13};
499
500		const HuffmanDecoder& literalDecoder = GetLiteralDecoder();
501		const HuffmanDecoder& distanceDecoder = GetDistanceDecoder();
502
503		switch (m_nextDecode)
504		{
505		case LITERAL:
506			while (true)
507			{
508				if (!literalDecoder.Decode(m_reader, m_literal))
509				{
510					m_nextDecode = LITERAL;
511					break;
512				}
513				if (m_literal < 256)
514					OutputByte((byte)m_literal);
515				else if (m_literal == 256)	// end of block
516				{
517					blockEnd = true;
518					break;
519				}
520				else
521				{
522					if (m_literal > 285)
523						throw BadBlockErr();
524					unsigned int bits;
525		case LENGTH_BITS:
526					bits = lengthExtraBits[m_literal-257];
527					if (!m_reader.FillBuffer(bits))
528					{
529						m_nextDecode = LENGTH_BITS;
530						break;
531					}
532					m_literal = m_reader.GetBits(bits) + lengthStarts[m_literal-257];
533		case DISTANCE:
534					if (!distanceDecoder.Decode(m_reader, m_distance))
535					{
536						m_nextDecode = DISTANCE;
537						break;
538					}
539		case DISTANCE_BITS:
540					bits = distanceExtraBits[m_distance];
541					if (!m_reader.FillBuffer(bits))
542					{
543						m_nextDecode = DISTANCE_BITS;
544						break;
545					}
546					m_distance = m_reader.GetBits(bits) + distanceStarts[m_distance];
547					OutputPast(m_literal, m_distance);
548				}
549			}
550		}
551	}
552	if (blockEnd)
553	{
554		if (m_eof)
555		{
556			FlushOutput();
557			m_reader.SkipBits(m_reader.BitsBuffered()%8);
558			if (m_reader.BitsBuffered())
559			{
560				// undo too much lookahead
561				SecBlockWithHint<byte, 4> buffer(m_reader.BitsBuffered() / 8);
562				for (unsigned int i=0; i<buffer.size(); i++)
563					buffer[i] = (byte)m_reader.GetBits(8);
564				m_inQueue.Unget(buffer, buffer.size());
565			}
566			m_state = POST_STREAM;
567		}
568		else
569			m_state = WAIT_HEADER;
570	}
571	return blockEnd;
572}
573
574void Inflator::FlushOutput()
575{
576	if (m_state != PRE_STREAM)
577	{
578		assert(m_current >= m_lastFlush);
579		ProcessDecompressedData(m_window + m_lastFlush, m_current - m_lastFlush);
580		m_lastFlush = m_current;
581	}
582}
583
584struct NewFixedLiteralDecoder
585{
586	HuffmanDecoder * operator()() const
587	{
588		unsigned int codeLengths[288];
589		std::fill(codeLengths + 0, codeLengths + 144, 8);
590		std::fill(codeLengths + 144, codeLengths + 256, 9);
591		std::fill(codeLengths + 256, codeLengths + 280, 7);
592		std::fill(codeLengths + 280, codeLengths + 288, 8);
593		std::auto_ptr<HuffmanDecoder> pDecoder(new HuffmanDecoder);
594		pDecoder->Initialize(codeLengths, 288);
595		return pDecoder.release();
596	}
597};
598
599struct NewFixedDistanceDecoder
600{
601	HuffmanDecoder * operator()() const
602	{
603		unsigned int codeLengths[32];
604		std::fill(codeLengths + 0, codeLengths + 32, 5);
605		std::auto_ptr<HuffmanDecoder> pDecoder(new HuffmanDecoder);
606		pDecoder->Initialize(codeLengths, 32);
607		return pDecoder.release();
608	}
609};
610
611const HuffmanDecoder& Inflator::GetLiteralDecoder() const
612{
613	return m_blockType == 1 ? Singleton<HuffmanDecoder, NewFixedLiteralDecoder>().Ref() : m_dynamicLiteralDecoder;
614}
615
616const HuffmanDecoder& Inflator::GetDistanceDecoder() const
617{
618	return m_blockType == 1 ? Singleton<HuffmanDecoder, NewFixedDistanceDecoder>().Ref() : m_dynamicDistanceDecoder;
619}
620
621NAMESPACE_END
622