1// zinflate.cpp - written and placed in the public domain by Wei Dai 2 3// This is a complete reimplementation of the DEFLATE decompression algorithm. 4// It should not be affected by any security vulnerabilities in the zlib 5// compression library. In particular it is not affected by the double free bug 6// (http://www.kb.cert.org/vuls/id/368819). 7 8#include "pch.h" 9#include "zinflate.h" 10 11NAMESPACE_BEGIN(CryptoPP) 12 13struct CodeLessThan 14{ 15 inline bool operator()(CryptoPP::HuffmanDecoder::code_t lhs, const CryptoPP::HuffmanDecoder::CodeInfo &rhs) 16 {return lhs < rhs.code;} 17 // needed for MSVC .NET 2005 18 inline bool operator()(const CryptoPP::HuffmanDecoder::CodeInfo &lhs, const CryptoPP::HuffmanDecoder::CodeInfo &rhs) 19 {return lhs.code < rhs.code;} 20}; 21 22inline bool LowFirstBitReader::FillBuffer(unsigned int length) 23{ 24 while (m_bitsBuffered < length) 25 { 26 byte b; 27 if (!m_store.Get(b)) 28 return false; 29 m_buffer |= (unsigned long)b << m_bitsBuffered; 30 m_bitsBuffered += 8; 31 } 32 assert(m_bitsBuffered <= sizeof(unsigned long)*8); 33 return true; 34} 35 36inline unsigned long LowFirstBitReader::PeekBits(unsigned int length) 37{ 38 bool result = FillBuffer(length); 39 assert(result); 40 return m_buffer & (((unsigned long)1 << length) - 1); 41} 42 43inline void LowFirstBitReader::SkipBits(unsigned int length) 44{ 45 assert(m_bitsBuffered >= length); 46 m_buffer >>= length; 47 m_bitsBuffered -= length; 48} 49 50inline unsigned long LowFirstBitReader::GetBits(unsigned int length) 51{ 52 unsigned long result = PeekBits(length); 53 SkipBits(length); 54 return result; 55} 56 57inline HuffmanDecoder::code_t HuffmanDecoder::NormalizeCode(HuffmanDecoder::code_t code, unsigned int codeBits) 58{ 59 return code << (MAX_CODE_BITS - codeBits); 60} 61 62void HuffmanDecoder::Initialize(const unsigned int *codeBits, unsigned int nCodes) 63{ 64 // the Huffman codes are represented in 3 ways in this code: 65 // 66 // 1. most significant code bit (i.e. top of code tree) in the least significant bit position 67 // 2. most significant code bit (i.e. top of code tree) in the most significant bit position 68 // 3. most significant code bit (i.e. top of code tree) in n-th least significant bit position, 69 // where n is the maximum code length for this code tree 70 // 71 // (1) is the way the codes come in from the deflate stream 72 // (2) is used to sort codes so they can be binary searched 73 // (3) is used in this function to compute codes from code lengths 74 // 75 // a code in representation (2) is called "normalized" here 76 // The BitReverse() function is used to convert between (1) and (2) 77 // The NormalizeCode() function is used to convert from (3) to (2) 78 79 if (nCodes == 0) 80 throw Err("null code"); 81 82 m_maxCodeBits = *std::max_element(codeBits, codeBits+nCodes); 83 84 if (m_maxCodeBits > MAX_CODE_BITS) 85 throw Err("code length exceeds maximum"); 86 87 if (m_maxCodeBits == 0) 88 throw Err("null code"); 89 90 // count number of codes of each length 91 SecBlockWithHint<unsigned int, 15+1> blCount(m_maxCodeBits+1); 92 std::fill(blCount.begin(), blCount.end(), 0); 93 unsigned int i; 94 for (i=0; i<nCodes; i++) 95 blCount[codeBits[i]]++; 96 97 // compute the starting code of each length 98 code_t code = 0; 99 SecBlockWithHint<code_t, 15+1> nextCode(m_maxCodeBits+1); 100 nextCode[1] = 0; 101 for (i=2; i<=m_maxCodeBits; i++) 102 { 103 // compute this while checking for overflow: code = (code + blCount[i-1]) << 1 104 if (code > code + blCount[i-1]) 105 throw Err("codes oversubscribed"); 106 code += blCount[i-1]; 107 if (code > (code << 1)) 108 throw Err("codes oversubscribed"); 109 code <<= 1; 110 nextCode[i] = code; 111 } 112 113 if (code > (1 << m_maxCodeBits) - blCount[m_maxCodeBits]) 114 throw Err("codes oversubscribed"); 115 else if (m_maxCodeBits != 1 && code < (1 << m_maxCodeBits) - blCount[m_maxCodeBits]) 116 throw Err("codes incomplete"); 117 118 // compute a vector of <code, length, value> triples sorted by code 119 m_codeToValue.resize(nCodes - blCount[0]); 120 unsigned int j=0; 121 for (i=0; i<nCodes; i++) 122 { 123 unsigned int len = codeBits[i]; 124 if (len != 0) 125 { 126 code = NormalizeCode(nextCode[len]++, len); 127 m_codeToValue[j].code = code; 128 m_codeToValue[j].len = len; 129 m_codeToValue[j].value = i; 130 j++; 131 } 132 } 133 std::sort(m_codeToValue.begin(), m_codeToValue.end()); 134 135 // initialize the decoding cache 136 m_cacheBits = STDMIN(9U, m_maxCodeBits); 137 m_cacheMask = (1 << m_cacheBits) - 1; 138 m_normalizedCacheMask = NormalizeCode(m_cacheMask, m_cacheBits); 139 assert(m_normalizedCacheMask == BitReverse(m_cacheMask)); 140 141 if (m_cache.size() != size_t(1) << m_cacheBits) 142 m_cache.resize(1 << m_cacheBits); 143 144 for (i=0; i<m_cache.size(); i++) 145 m_cache[i].type = 0; 146} 147 148void HuffmanDecoder::FillCacheEntry(LookupEntry &entry, code_t normalizedCode) const 149{ 150 normalizedCode &= m_normalizedCacheMask; 151 const CodeInfo &codeInfo = *(std::upper_bound(m_codeToValue.begin(), m_codeToValue.end(), normalizedCode, CodeLessThan())-1); 152 if (codeInfo.len <= m_cacheBits) 153 { 154 entry.type = 1; 155 entry.value = codeInfo.value; 156 entry.len = codeInfo.len; 157 } 158 else 159 { 160 entry.begin = &codeInfo; 161 const CodeInfo *last = & *(std::upper_bound(m_codeToValue.begin(), m_codeToValue.end(), normalizedCode + ~m_normalizedCacheMask, CodeLessThan())-1); 162 if (codeInfo.len == last->len) 163 { 164 entry.type = 2; 165 entry.len = codeInfo.len; 166 } 167 else 168 { 169 entry.type = 3; 170 entry.end = last+1; 171 } 172 } 173} 174 175inline unsigned int HuffmanDecoder::Decode(code_t code, /* out */ value_t &value) const 176{ 177 assert(m_codeToValue.size() > 0); 178 LookupEntry &entry = m_cache[code & m_cacheMask]; 179 180 code_t normalizedCode; 181 if (entry.type != 1) 182 normalizedCode = BitReverse(code); 183 184 if (entry.type == 0) 185 FillCacheEntry(entry, normalizedCode); 186 187 if (entry.type == 1) 188 { 189 value = entry.value; 190 return entry.len; 191 } 192 else 193 { 194 const CodeInfo &codeInfo = (entry.type == 2) 195 ? entry.begin[(normalizedCode << m_cacheBits) >> (MAX_CODE_BITS - (entry.len - m_cacheBits))] 196 : *(std::upper_bound(entry.begin, entry.end, normalizedCode, CodeLessThan())-1); 197 value = codeInfo.value; 198 return codeInfo.len; 199 } 200} 201 202bool HuffmanDecoder::Decode(LowFirstBitReader &reader, value_t &value) const 203{ 204 reader.FillBuffer(m_maxCodeBits); 205 unsigned int codeBits = Decode(reader.PeekBuffer(), value); 206 if (codeBits > reader.BitsBuffered()) 207 return false; 208 reader.SkipBits(codeBits); 209 return true; 210} 211 212// ************************************************************* 213 214Inflator::Inflator(BufferedTransformation *attachment, bool repeat, int propagation) 215 : AutoSignaling<Filter>(propagation) 216 , m_state(PRE_STREAM), m_repeat(repeat), m_reader(m_inQueue) 217{ 218 Detach(attachment); 219} 220 221void Inflator::IsolatedInitialize(const NameValuePairs ¶meters) 222{ 223 m_state = PRE_STREAM; 224 parameters.GetValue("Repeat", m_repeat); 225 m_inQueue.Clear(); 226 m_reader.SkipBits(m_reader.BitsBuffered()); 227} 228 229void Inflator::OutputByte(byte b) 230{ 231 m_window[m_current++] = b; 232 if (m_current == m_window.size()) 233 { 234 ProcessDecompressedData(m_window + m_lastFlush, m_window.size() - m_lastFlush); 235 m_lastFlush = 0; 236 m_current = 0; 237 m_wrappedAround = true; 238 } 239} 240 241void Inflator::OutputString(const byte *string, size_t length) 242{ 243 while (length) 244 { 245 size_t len = UnsignedMin(length, m_window.size() - m_current); 246 memcpy(m_window + m_current, string, len); 247 m_current += len; 248 if (m_current == m_window.size()) 249 { 250 ProcessDecompressedData(m_window + m_lastFlush, m_window.size() - m_lastFlush); 251 m_lastFlush = 0; 252 m_current = 0; 253 m_wrappedAround = true; 254 } 255 string += len; 256 length -= len; 257 } 258} 259 260void Inflator::OutputPast(unsigned int length, unsigned int distance) 261{ 262 size_t start; 263 if (distance <= m_current) 264 start = m_current - distance; 265 else if (m_wrappedAround && distance <= m_window.size()) 266 start = m_current + m_window.size() - distance; 267 else 268 throw BadBlockErr(); 269 270 if (start + length > m_window.size()) 271 { 272 for (; start < m_window.size(); start++, length--) 273 OutputByte(m_window[start]); 274 start = 0; 275 } 276 277 if (start + length > m_current || m_current + length >= m_window.size()) 278 { 279 while (length--) 280 OutputByte(m_window[start++]); 281 } 282 else 283 { 284 memcpy(m_window + m_current, m_window + start, length); 285 m_current += length; 286 } 287} 288 289size_t Inflator::Put2(const byte *inString, size_t length, int messageEnd, bool blocking) 290{ 291 if (!blocking) 292 throw BlockingInputOnly("Inflator"); 293 294 LazyPutter lp(m_inQueue, inString, length); 295 ProcessInput(messageEnd != 0); 296 297 if (messageEnd) 298 if (!(m_state == PRE_STREAM || m_state == AFTER_END)) 299 throw UnexpectedEndErr(); 300 301 Output(0, NULL, 0, messageEnd, blocking); 302 return 0; 303} 304 305bool Inflator::IsolatedFlush(bool hardFlush, bool blocking) 306{ 307 if (!blocking) 308 throw BlockingInputOnly("Inflator"); 309 310 if (hardFlush) 311 ProcessInput(true); 312 FlushOutput(); 313 314 return false; 315} 316 317void Inflator::ProcessInput(bool flush) 318{ 319 while (true) 320 { 321 switch (m_state) 322 { 323 case PRE_STREAM: 324 if (!flush && m_inQueue.CurrentSize() < MaxPrestreamHeaderSize()) 325 return; 326 ProcessPrestreamHeader(); 327 m_state = WAIT_HEADER; 328 m_wrappedAround = false; 329 m_current = 0; 330 m_lastFlush = 0; 331 m_window.New(1 << GetLog2WindowSize()); 332 break; 333 case WAIT_HEADER: 334 { 335 // maximum number of bytes before actual compressed data starts 336 const size_t MAX_HEADER_SIZE = BitsToBytes(3+5+5+4+19*7+286*15+19*15); 337 if (m_inQueue.CurrentSize() < (flush ? 1 : MAX_HEADER_SIZE)) 338 return; 339 DecodeHeader(); 340 break; 341 } 342 case DECODING_BODY: 343 if (!DecodeBody()) 344 return; 345 break; 346 case POST_STREAM: 347 if (!flush && m_inQueue.CurrentSize() < MaxPoststreamTailSize()) 348 return; 349 ProcessPoststreamTail(); 350 m_state = m_repeat ? PRE_STREAM : AFTER_END; 351 Output(0, NULL, 0, GetAutoSignalPropagation(), true); // TODO: non-blocking 352 if (m_inQueue.IsEmpty()) 353 return; 354 break; 355 case AFTER_END: 356 m_inQueue.TransferTo(*AttachedTransformation()); 357 return; 358 } 359 } 360} 361 362void Inflator::DecodeHeader() 363{ 364 if (!m_reader.FillBuffer(3)) 365 throw UnexpectedEndErr(); 366 m_eof = m_reader.GetBits(1) != 0; 367 m_blockType = (byte)m_reader.GetBits(2); 368 switch (m_blockType) 369 { 370 case 0: // stored 371 { 372 m_reader.SkipBits(m_reader.BitsBuffered() % 8); 373 if (!m_reader.FillBuffer(32)) 374 throw UnexpectedEndErr(); 375 m_storedLen = (word16)m_reader.GetBits(16); 376 word16 nlen = (word16)m_reader.GetBits(16); 377 if (nlen != (word16)~m_storedLen) 378 throw BadBlockErr(); 379 break; 380 } 381 case 1: // fixed codes 382 m_nextDecode = LITERAL; 383 break; 384 case 2: // dynamic codes 385 { 386 if (!m_reader.FillBuffer(5+5+4)) 387 throw UnexpectedEndErr(); 388 unsigned int hlit = m_reader.GetBits(5); 389 unsigned int hdist = m_reader.GetBits(5); 390 unsigned int hclen = m_reader.GetBits(4); 391 392 FixedSizeSecBlock<unsigned int, 286+32> codeLengths; 393 unsigned int i; 394 static const unsigned int border[] = { // Order of the bit length code lengths 395 16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15}; 396 std::fill(codeLengths.begin(), codeLengths+19, 0); 397 for (i=0; i<hclen+4; i++) 398 codeLengths[border[i]] = m_reader.GetBits(3); 399 400 try 401 { 402 HuffmanDecoder codeLengthDecoder(codeLengths, 19); 403 for (i = 0; i < hlit+257+hdist+1; ) 404 { 405 unsigned int k, count, repeater; 406 bool result = codeLengthDecoder.Decode(m_reader, k); 407 if (!result) 408 throw UnexpectedEndErr(); 409 if (k <= 15) 410 { 411 count = 1; 412 repeater = k; 413 } 414 else switch (k) 415 { 416 case 16: 417 if (!m_reader.FillBuffer(2)) 418 throw UnexpectedEndErr(); 419 count = 3 + m_reader.GetBits(2); 420 if (i == 0) 421 throw BadBlockErr(); 422 repeater = codeLengths[i-1]; 423 break; 424 case 17: 425 if (!m_reader.FillBuffer(3)) 426 throw UnexpectedEndErr(); 427 count = 3 + m_reader.GetBits(3); 428 repeater = 0; 429 break; 430 case 18: 431 if (!m_reader.FillBuffer(7)) 432 throw UnexpectedEndErr(); 433 count = 11 + m_reader.GetBits(7); 434 repeater = 0; 435 break; 436 } 437 if (i + count > hlit+257+hdist+1) 438 throw BadBlockErr(); 439 std::fill(codeLengths + i, codeLengths + i + count, repeater); 440 i += count; 441 } 442 m_dynamicLiteralDecoder.Initialize(codeLengths, hlit+257); 443 if (hdist == 0 && codeLengths[hlit+257] == 0) 444 { 445 if (hlit != 0) // a single zero distance code length means all literals 446 throw BadBlockErr(); 447 } 448 else 449 m_dynamicDistanceDecoder.Initialize(codeLengths+hlit+257, hdist+1); 450 m_nextDecode = LITERAL; 451 } 452 catch (HuffmanDecoder::Err &) 453 { 454 throw BadBlockErr(); 455 } 456 break; 457 } 458 default: 459 throw BadBlockErr(); // reserved block type 460 } 461 m_state = DECODING_BODY; 462} 463 464bool Inflator::DecodeBody() 465{ 466 bool blockEnd = false; 467 switch (m_blockType) 468 { 469 case 0: // stored 470 assert(m_reader.BitsBuffered() == 0); 471 while (!m_inQueue.IsEmpty() && !blockEnd) 472 { 473 size_t size; 474 const byte *block = m_inQueue.Spy(size); 475 size = UnsignedMin(m_storedLen, size); 476 OutputString(block, size); 477 m_inQueue.Skip(size); 478 m_storedLen -= (word16)size; 479 if (m_storedLen == 0) 480 blockEnd = true; 481 } 482 break; 483 case 1: // fixed codes 484 case 2: // dynamic codes 485 static const unsigned int lengthStarts[] = { 486 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31, 487 35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227, 258}; 488 static const unsigned int lengthExtraBits[] = { 489 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 490 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0}; 491 static const unsigned int distanceStarts[] = { 492 1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193, 493 257, 385, 513, 769, 1025, 1537, 2049, 3073, 4097, 6145, 494 8193, 12289, 16385, 24577}; 495 static const unsigned int distanceExtraBits[] = { 496 0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 497 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 498 12, 12, 13, 13}; 499 500 const HuffmanDecoder& literalDecoder = GetLiteralDecoder(); 501 const HuffmanDecoder& distanceDecoder = GetDistanceDecoder(); 502 503 switch (m_nextDecode) 504 { 505 case LITERAL: 506 while (true) 507 { 508 if (!literalDecoder.Decode(m_reader, m_literal)) 509 { 510 m_nextDecode = LITERAL; 511 break; 512 } 513 if (m_literal < 256) 514 OutputByte((byte)m_literal); 515 else if (m_literal == 256) // end of block 516 { 517 blockEnd = true; 518 break; 519 } 520 else 521 { 522 if (m_literal > 285) 523 throw BadBlockErr(); 524 unsigned int bits; 525 case LENGTH_BITS: 526 bits = lengthExtraBits[m_literal-257]; 527 if (!m_reader.FillBuffer(bits)) 528 { 529 m_nextDecode = LENGTH_BITS; 530 break; 531 } 532 m_literal = m_reader.GetBits(bits) + lengthStarts[m_literal-257]; 533 case DISTANCE: 534 if (!distanceDecoder.Decode(m_reader, m_distance)) 535 { 536 m_nextDecode = DISTANCE; 537 break; 538 } 539 case DISTANCE_BITS: 540 bits = distanceExtraBits[m_distance]; 541 if (!m_reader.FillBuffer(bits)) 542 { 543 m_nextDecode = DISTANCE_BITS; 544 break; 545 } 546 m_distance = m_reader.GetBits(bits) + distanceStarts[m_distance]; 547 OutputPast(m_literal, m_distance); 548 } 549 } 550 } 551 } 552 if (blockEnd) 553 { 554 if (m_eof) 555 { 556 FlushOutput(); 557 m_reader.SkipBits(m_reader.BitsBuffered()%8); 558 if (m_reader.BitsBuffered()) 559 { 560 // undo too much lookahead 561 SecBlockWithHint<byte, 4> buffer(m_reader.BitsBuffered() / 8); 562 for (unsigned int i=0; i<buffer.size(); i++) 563 buffer[i] = (byte)m_reader.GetBits(8); 564 m_inQueue.Unget(buffer, buffer.size()); 565 } 566 m_state = POST_STREAM; 567 } 568 else 569 m_state = WAIT_HEADER; 570 } 571 return blockEnd; 572} 573 574void Inflator::FlushOutput() 575{ 576 if (m_state != PRE_STREAM) 577 { 578 assert(m_current >= m_lastFlush); 579 ProcessDecompressedData(m_window + m_lastFlush, m_current - m_lastFlush); 580 m_lastFlush = m_current; 581 } 582} 583 584struct NewFixedLiteralDecoder 585{ 586 HuffmanDecoder * operator()() const 587 { 588 unsigned int codeLengths[288]; 589 std::fill(codeLengths + 0, codeLengths + 144, 8); 590 std::fill(codeLengths + 144, codeLengths + 256, 9); 591 std::fill(codeLengths + 256, codeLengths + 280, 7); 592 std::fill(codeLengths + 280, codeLengths + 288, 8); 593 std::auto_ptr<HuffmanDecoder> pDecoder(new HuffmanDecoder); 594 pDecoder->Initialize(codeLengths, 288); 595 return pDecoder.release(); 596 } 597}; 598 599struct NewFixedDistanceDecoder 600{ 601 HuffmanDecoder * operator()() const 602 { 603 unsigned int codeLengths[32]; 604 std::fill(codeLengths + 0, codeLengths + 32, 5); 605 std::auto_ptr<HuffmanDecoder> pDecoder(new HuffmanDecoder); 606 pDecoder->Initialize(codeLengths, 32); 607 return pDecoder.release(); 608 } 609}; 610 611const HuffmanDecoder& Inflator::GetLiteralDecoder() const 612{ 613 return m_blockType == 1 ? Singleton<HuffmanDecoder, NewFixedLiteralDecoder>().Ref() : m_dynamicLiteralDecoder; 614} 615 616const HuffmanDecoder& Inflator::GetDistanceDecoder() const 617{ 618 return m_blockType == 1 ? Singleton<HuffmanDecoder, NewFixedDistanceDecoder>().Ref() : m_dynamicDistanceDecoder; 619} 620 621NAMESPACE_END 622