Main Page | Namespace List | Class Hierarchy | Alphabetical List | Class List | File List | Namespace Members | Class Members | File Members

zinflate.cpp

00001 // zinflate.cpp - written and placed in the public domain by Wei Dai 00002 00003 // This is a complete reimplementation of the DEFLATE decompression algorithm. 00004 // It should not be affected by any security vulnerabilities in the zlib 00005 // compression library. In particular it is not affected by the double free bug 00006 // (http://www.kb.cert.org/vuls/id/368819). 00007 00008 #include "pch.h" 00009 #include "zinflate.h" 00010 00011 NAMESPACE_BEGIN(CryptoPP) 00012 00013 struct CodeLessThan 00014 { 00015 inline bool operator()(const CryptoPP::HuffmanDecoder::code_t lhs, const CryptoPP::HuffmanDecoder::CodeInfo &rhs) 00016 {return lhs < rhs.code;} 00017 }; 00018 00019 inline bool LowFirstBitReader::FillBuffer(unsigned int length) 00020 { 00021 while (m_bitsBuffered < length) 00022 { 00023 byte b; 00024 if (!m_store.Get(b)) 00025 return false; 00026 m_buffer |= (unsigned long)b << m_bitsBuffered; 00027 m_bitsBuffered += 8; 00028 } 00029 assert(m_bitsBuffered <= sizeof(unsigned long)*8); 00030 return true; 00031 } 00032 00033 inline unsigned long LowFirstBitReader::PeekBits(unsigned int length) 00034 { 00035 bool result = FillBuffer(length); 00036 assert(result); 00037 return m_buffer & (((unsigned long)1 << length) - 1); 00038 } 00039 00040 inline void LowFirstBitReader::SkipBits(unsigned int length) 00041 { 00042 assert(m_bitsBuffered >= length); 00043 m_buffer >>= length; 00044 m_bitsBuffered -= length; 00045 } 00046 00047 inline unsigned long LowFirstBitReader::GetBits(unsigned int length) 00048 { 00049 unsigned long result = PeekBits(length); 00050 SkipBits(length); 00051 return result; 00052 } 00053 00054 inline HuffmanDecoder::code_t HuffmanDecoder::NormalizeCode(HuffmanDecoder::code_t code, unsigned int codeBits) 00055 { 00056 return code << (MAX_CODE_BITS - codeBits); 00057 } 00058 00059 void HuffmanDecoder::Initialize(const unsigned int *codeBits, unsigned int nCodes) 00060 { 00061 // the Huffman codes are represented in 3 ways in this code: 00062 // 00063 // 1. most significant code bit (i.e. top of code tree) in the least significant bit position 00064 // 2. most significant code bit (i.e. top of code tree) in the most significant bit position 00065 // 3. most significant code bit (i.e. top of code tree) in n-th least significant bit position, 00066 // where n is the maximum code length for this code tree 00067 // 00068 // (1) is the way the codes come in from the deflate stream 00069 // (2) is used to sort codes so they can be binary searched 00070 // (3) is used in this function to compute codes from code lengths 00071 // 00072 // a code in representation (2) is called "normalized" here 00073 // The BitReverse() function is used to convert between (1) and (2) 00074 // The NormalizeCode() function is used to convert from (3) to (2) 00075 00076 if (nCodes == 0) 00077 throw Err("null code"); 00078 00079 m_maxCodeBits = *std::max_element(codeBits, codeBits+nCodes); 00080 00081 if (m_maxCodeBits > MAX_CODE_BITS) 00082 throw Err("code length exceeds maximum"); 00083 00084 if (m_maxCodeBits == 0) 00085 throw Err("null code"); 00086 00087 // count number of codes of each length 00088 SecBlockWithHint<unsigned int, 15+1> blCount(m_maxCodeBits+1); 00089 std::fill(blCount.begin(), blCount.end(), 0); 00090 unsigned int i; 00091 for (i=0; i<nCodes; i++) 00092 blCount[codeBits[i]]++; 00093 00094 // compute the starting code of each length 00095 code_t code = 0; 00096 SecBlockWithHint<code_t, 15+1> nextCode(m_maxCodeBits+1); 00097 nextCode[1] = 0; 00098 for (i=2; i<=m_maxCodeBits; i++) 00099 { 00100 // compute this while checking for overflow: code = (code + blCount[i-1]) << 1 00101 if (code > code + blCount[i-1]) 00102 throw Err("codes oversubscribed"); 00103 code += blCount[i-1]; 00104 if (code > (code << 1)) 00105 throw Err("codes oversubscribed"); 00106 code <<= 1; 00107 nextCode[i] = code; 00108 } 00109 00110 if (code > (1 << m_maxCodeBits) - blCount[m_maxCodeBits]) 00111 throw Err("codes oversubscribed"); 00112 else if (m_maxCodeBits != 1 && code < (1 << m_maxCodeBits) - blCount[m_maxCodeBits]) 00113 throw Err("codes incomplete"); 00114 00115 // compute a vector of <code, length, value> triples sorted by code 00116 m_codeToValue.resize(nCodes - blCount[0]); 00117 unsigned int j=0; 00118 for (i=0; i<nCodes; i++) 00119 { 00120 unsigned int len = codeBits[i]; 00121 if (len != 0) 00122 { 00123 code = NormalizeCode(nextCode[len]++, len); 00124 m_codeToValue[j].code = code; 00125 m_codeToValue[j].len = len; 00126 m_codeToValue[j].value = i; 00127 j++; 00128 } 00129 } 00130 std::sort(m_codeToValue.begin(), m_codeToValue.end()); 00131 00132 // initialize the decoding cache 00133 m_cacheBits = STDMIN(9U, m_maxCodeBits); 00134 m_cacheMask = (1 << m_cacheBits) - 1; 00135 m_normalizedCacheMask = NormalizeCode(m_cacheMask, m_cacheBits); 00136 assert(m_normalizedCacheMask == BitReverse(m_cacheMask)); 00137 00138 if (m_cache.size() != 1 << m_cacheBits) 00139 m_cache.resize(1 << m_cacheBits); 00140 00141 for (i=0; i<m_cache.size(); i++) 00142 m_cache[i].type = 0; 00143 } 00144 00145 void HuffmanDecoder::FillCacheEntry(LookupEntry &entry, code_t normalizedCode) const 00146 { 00147 normalizedCode &= m_normalizedCacheMask; 00148 const CodeInfo &codeInfo = *(std::upper_bound(m_codeToValue.begin(), m_codeToValue.end(), normalizedCode, CodeLessThan())-1); 00149 if (codeInfo.len <= m_cacheBits) 00150 { 00151 entry.type = 1; 00152 entry.value = codeInfo.value; 00153 entry.len = codeInfo.len; 00154 } 00155 else 00156 { 00157 entry.begin = &codeInfo; 00158 const CodeInfo *last = & *(std::upper_bound(m_codeToValue.begin(), m_codeToValue.end(), normalizedCode + ~m_normalizedCacheMask, CodeLessThan())-1); 00159 if (codeInfo.len == last->len) 00160 { 00161 entry.type = 2; 00162 entry.len = codeInfo.len; 00163 } 00164 else 00165 { 00166 entry.type = 3; 00167 entry.end = last+1; 00168 } 00169 } 00170 } 00171 00172 inline unsigned int HuffmanDecoder::Decode(code_t code, /* out */ value_t &value) const 00173 { 00174 assert(m_codeToValue.size() > 0); 00175 LookupEntry &entry = m_cache[code & m_cacheMask]; 00176 00177 code_t normalizedCode; 00178 if (entry.type != 1) 00179 normalizedCode = BitReverse(code); 00180 00181 if (entry.type == 0) 00182 FillCacheEntry(entry, normalizedCode); 00183 00184 if (entry.type == 1) 00185 { 00186 value = entry.value; 00187 return entry.len; 00188 } 00189 else 00190 { 00191 const CodeInfo &codeInfo = (entry.type == 2) 00192 ? entry.begin[(normalizedCode << m_cacheBits) >> (MAX_CODE_BITS - (entry.len - m_cacheBits))] 00193 : *(std::upper_bound(entry.begin, entry.end, normalizedCode, CodeLessThan())-1); 00194 value = codeInfo.value; 00195 return codeInfo.len; 00196 } 00197 } 00198 00199 bool HuffmanDecoder::Decode(LowFirstBitReader &reader, value_t &value) const 00200 { 00201 reader.FillBuffer(m_maxCodeBits); 00202 unsigned int codeBits = Decode(reader.PeekBuffer(), value); 00203 if (codeBits > reader.BitsBuffered()) 00204 return false; 00205 reader.SkipBits(codeBits); 00206 return true; 00207 } 00208 00209 // ************************************************************* 00210 00211 Inflator::Inflator(BufferedTransformation *attachment, bool repeat, int propagation) 00212 : AutoSignaling<Filter>(attachment, propagation) 00213 , m_state(PRE_STREAM), m_repeat(repeat) 00214 , m_decodersInitializedWithFixedCodes(false), m_reader(m_inQueue) 00215 { 00216 } 00217 00218 void Inflator::IsolatedInitialize(const NameValuePairs &parameters) 00219 { 00220 m_state = PRE_STREAM; 00221 parameters.GetValue("Repeat", m_repeat); 00222 m_inQueue.Clear(); 00223 m_reader.SkipBits(m_reader.BitsBuffered()); 00224 } 00225 00226 inline void Inflator::OutputByte(byte b) 00227 { 00228 m_window[m_current++] = b; 00229 if (m_current == m_window.size()) 00230 { 00231 ProcessDecompressedData(m_window + m_lastFlush, m_window.size() - m_lastFlush); 00232 m_lastFlush = 0; 00233 m_current = 0; 00234 } 00235 if (m_maxDistance < m_window.size()) 00236 m_maxDistance++; 00237 } 00238 00239 void Inflator::OutputString(const byte *string, unsigned int length) 00240 { 00241 while (length--) 00242 OutputByte(*string++); 00243 } 00244 00245 void Inflator::OutputPast(unsigned int length, unsigned int distance) 00246 { 00247 if (distance > m_maxDistance) 00248 throw BadBlockErr(); 00249 unsigned int start; 00250 if (m_current > distance) 00251 start = m_current - distance; 00252 else 00253 start = m_current + m_window.size() - distance; 00254 00255 if (start + length > m_window.size()) 00256 { 00257 for (; start < m_window.size(); start++, length--) 00258 OutputByte(m_window[start]); 00259 start = 0; 00260 } 00261 00262 if (start + length > m_current || m_current + length >= m_window.size()) 00263 { 00264 while (length--) 00265 OutputByte(m_window[start++]); 00266 } 00267 else 00268 { 00269 memcpy(m_window + m_current, m_window + start, length); 00270 m_current += length; 00271 m_maxDistance = STDMIN((unsigned int)m_window.size(), m_maxDistance + length); 00272 } 00273 } 00274 00275 unsigned int Inflator::Put2(const byte *inString, unsigned int length, int messageEnd, bool blocking) 00276 { 00277 if (!blocking) 00278 throw BlockingInputOnly("Inflator"); 00279 00280 LazyPutter lp(m_inQueue, inString, length); 00281 ProcessInput(messageEnd != 0); 00282 00283 if (messageEnd) 00284 if (!(m_state == PRE_STREAM || m_state == AFTER_END)) 00285 throw UnexpectedEndErr(); 00286 00287 Output(0, NULL, 0, messageEnd, blocking); 00288 return 0; 00289 } 00290 00291 bool Inflator::IsolatedFlush(bool hardFlush, bool blocking) 00292 { 00293 if (!blocking) 00294 throw BlockingInputOnly("Inflator"); 00295 00296 if (hardFlush) 00297 ProcessInput(true); 00298 FlushOutput(); 00299 00300 return false; 00301 } 00302 00303 void Inflator::ProcessInput(bool flush) 00304 { 00305 while (true) 00306 { 00307 if (m_inQueue.IsEmpty()) 00308 return; 00309 00310 switch (m_state) 00311 { 00312 case PRE_STREAM: 00313 if (!flush && m_inQueue.CurrentSize() < MaxPrestreamHeaderSize()) 00314 return; 00315 ProcessPrestreamHeader(); 00316 m_state = WAIT_HEADER; 00317 m_maxDistance = 0; 00318 m_current = 0; 00319 m_lastFlush = 0; 00320 m_window.New(1 << GetLog2WindowSize()); 00321 break; 00322 case WAIT_HEADER: 00323 { 00324 // maximum number of bytes before actual compressed data starts 00325 const unsigned int MAX_HEADER_SIZE = BitsToBytes(3+5+5+4+19*7+286*15+19*15); 00326 if (m_inQueue.CurrentSize() < (flush ? 1 : MAX_HEADER_SIZE)) 00327 return; 00328 DecodeHeader(); 00329 break; 00330 } 00331 case DECODING_BODY: 00332 if (!DecodeBody()) 00333 return; 00334 break; 00335 case POST_STREAM: 00336 if (!flush && m_inQueue.CurrentSize() < MaxPoststreamTailSize()) 00337 return; 00338 ProcessPoststreamTail(); 00339 m_state = m_repeat ? PRE_STREAM : AFTER_END; 00340 Output(0, NULL, 0, GetAutoSignalPropagation(), true); // TODO: non-blocking 00341 break; 00342 case AFTER_END: 00343 m_inQueue.TransferTo(*AttachedTransformation()); 00344 return; 00345 } 00346 } 00347 } 00348 00349 void Inflator::DecodeHeader() 00350 { 00351 if (!m_reader.FillBuffer(3)) 00352 throw UnexpectedEndErr(); 00353 m_eof = m_reader.GetBits(1) != 0; 00354 m_blockType = (byte)m_reader.GetBits(2); 00355 switch (m_blockType) 00356 { 00357 case 0: // stored 00358 { 00359 m_reader.SkipBits(m_reader.BitsBuffered() % 8); 00360 if (!m_reader.FillBuffer(32)) 00361 throw UnexpectedEndErr(); 00362 m_storedLen = (word16)m_reader.GetBits(16); 00363 word16 nlen = (word16)m_reader.GetBits(16); 00364 if (nlen != (word16)~m_storedLen) 00365 throw BadBlockErr(); 00366 break; 00367 } 00368 case 1: // fixed codes 00369 if (!m_decodersInitializedWithFixedCodes) 00370 { 00371 unsigned int codeLengths[288]; 00372 std::fill(codeLengths + 0, codeLengths + 144, 8); 00373 std::fill(codeLengths + 144, codeLengths + 256, 9); 00374 std::fill(codeLengths + 256, codeLengths + 280, 7); 00375 std::fill(codeLengths + 280, codeLengths + 288, 8); 00376 m_literalDecoder.Initialize(codeLengths, 288); 00377 std::fill(codeLengths + 0, codeLengths + 32, 5); 00378 m_distanceDecoder.Initialize(codeLengths, 32); 00379 m_decodersInitializedWithFixedCodes = true; 00380 } 00381 m_nextDecode = LITERAL; 00382 break; 00383 case 2: // dynamic codes 00384 { 00385 m_decodersInitializedWithFixedCodes = false; 00386 if (!m_reader.FillBuffer(5+5+4)) 00387 throw UnexpectedEndErr(); 00388 unsigned int hlit = m_reader.GetBits(5); 00389 unsigned int hdist = m_reader.GetBits(5); 00390 unsigned int hclen = m_reader.GetBits(4); 00391 00392 FixedSizeSecBlock<unsigned int, 286+32> codeLengths; 00393 unsigned int i; 00394 static const unsigned int border[] = { // Order of the bit length code lengths 00395 16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15}; 00396 std::fill(codeLengths.begin(), codeLengths+19, 0); 00397 for (i=0; i<hclen+4; i++) 00398 codeLengths[border[i]] = m_reader.GetBits(3); 00399 00400 try 00401 { 00402 HuffmanDecoder codeLengthDecoder(codeLengths, 19); 00403 for (i = 0; i < hlit+257+hdist+1; ) 00404 { 00405 unsigned int k, count, repeater; 00406 bool result = codeLengthDecoder.Decode(m_reader, k); 00407 if (!result) 00408 throw UnexpectedEndErr(); 00409 if (k <= 15) 00410 { 00411 count = 1; 00412 repeater = k; 00413 } 00414 else switch (k) 00415 { 00416 case 16: 00417 if (!m_reader.FillBuffer(2)) 00418 throw UnexpectedEndErr(); 00419 count = 3 + m_reader.GetBits(2); 00420 if (i == 0) 00421 throw BadBlockErr(); 00422 repeater = codeLengths[i-1]; 00423 break; 00424 case 17: 00425 if (!m_reader.FillBuffer(3)) 00426 throw UnexpectedEndErr(); 00427 count = 3 + m_reader.GetBits(3); 00428 repeater = 0; 00429 break; 00430 case 18: 00431 if (!m_reader.FillBuffer(7)) 00432 throw UnexpectedEndErr(); 00433 count = 11 + m_reader.GetBits(7); 00434 repeater = 0; 00435 break; 00436 } 00437 if (i + count > hlit+257+hdist+1) 00438 throw BadBlockErr(); 00439 std::fill(codeLengths + i, codeLengths + i + count, repeater); 00440 i += count; 00441 } 00442 m_literalDecoder.Initialize(codeLengths, hlit+257); 00443 if (hdist == 0 && codeLengths[hlit+257] == 0) 00444 { 00445 if (hlit != 0) // a single zero distance code length means all literals 00446 throw BadBlockErr(); 00447 } 00448 else 00449 m_distanceDecoder.Initialize(codeLengths+hlit+257, hdist+1); 00450 m_nextDecode = LITERAL; 00451 } 00452 catch (HuffmanDecoder::Err &) 00453 { 00454 throw BadBlockErr(); 00455 } 00456 break; 00457 } 00458 default: 00459 throw BadBlockErr(); // reserved block type 00460 } 00461 m_state = DECODING_BODY; 00462 } 00463 00464 bool Inflator::DecodeBody() 00465 { 00466 bool blockEnd = false; 00467 switch (m_blockType) 00468 { 00469 case 0: // stored 00470 assert(m_reader.BitsBuffered() == 0); 00471 while (!m_inQueue.IsEmpty() && !blockEnd) 00472 { 00473 unsigned int size; 00474 const byte *block = m_inQueue.Spy(size); 00475 size = STDMIN(size, (unsigned int)m_storedLen); 00476 OutputString(block, size); 00477 m_inQueue.Skip(size); 00478 m_storedLen -= size; 00479 if (m_storedLen == 0) 00480 blockEnd = true; 00481 } 00482 break; 00483 case 1: // fixed codes 00484 case 2: // dynamic codes 00485 static const unsigned int lengthStarts[] = { 00486 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31, 00487 35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227, 258}; 00488 static const unsigned int lengthExtraBits[] = { 00489 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 00490 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0}; 00491 static const unsigned int distanceStarts[] = { 00492 1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193, 00493 257, 385, 513, 769, 1025, 1537, 2049, 3073, 4097, 6145, 00494 8193, 12289, 16385, 24577}; 00495 static const unsigned int distanceExtraBits[] = { 00496 0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 00497 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 00498 12, 12, 13, 13}; 00499 00500 switch (m_nextDecode) 00501 { 00502 while (true) 00503 { 00504 case LITERAL: 00505 if (!m_literalDecoder.Decode(m_reader, m_literal)) 00506 { 00507 m_nextDecode = LITERAL; 00508 break; 00509 } 00510 if (m_literal < 256) 00511 OutputByte((byte)m_literal); 00512 else if (m_literal == 256) // end of block 00513 { 00514 blockEnd = true; 00515 break; 00516 } 00517 else 00518 { 00519 if (m_literal > 285) 00520 throw BadBlockErr(); 00521 unsigned int bits; 00522 case LENGTH_BITS: 00523 bits = lengthExtraBits[m_literal-257]; 00524 if (!m_reader.FillBuffer(bits)) 00525 { 00526 m_nextDecode = LENGTH_BITS; 00527 break; 00528 } 00529 m_literal = m_reader.GetBits(bits) + lengthStarts[m_literal-257]; 00530 case DISTANCE: 00531 if (!m_distanceDecoder.Decode(m_reader, m_distance)) 00532 { 00533 m_nextDecode = DISTANCE; 00534 break; 00535 } 00536 case DISTANCE_BITS: 00537 bits = distanceExtraBits[m_distance]; 00538 if (!m_reader.FillBuffer(bits)) 00539 { 00540 m_nextDecode = DISTANCE_BITS; 00541 break; 00542 } 00543 m_distance = m_reader.GetBits(bits) + distanceStarts[m_distance]; 00544 OutputPast(m_literal, m_distance); 00545 } 00546 } 00547 } 00548 } 00549 if (blockEnd) 00550 { 00551 if (m_eof) 00552 { 00553 FlushOutput(); 00554 m_reader.SkipBits(m_reader.BitsBuffered()%8); 00555 if (m_reader.BitsBuffered()) 00556 { 00557 // undo too much lookahead 00558 SecBlockWithHint<byte, 4> buffer(m_reader.BitsBuffered() / 8); 00559 for (unsigned int i=0; i<buffer.size(); i++) 00560 buffer[i] = (byte)m_reader.GetBits(8); 00561 m_inQueue.Unget(buffer, buffer.size()); 00562 } 00563 m_state = POST_STREAM; 00564 } 00565 else 00566 m_state = WAIT_HEADER; 00567 } 00568 return blockEnd; 00569 } 00570 00571 void Inflator::FlushOutput() 00572 { 00573 if (m_state != PRE_STREAM) 00574 { 00575 assert(m_current >= m_lastFlush); 00576 ProcessDecompressedData(m_window + m_lastFlush, m_current - m_lastFlush); 00577 m_lastFlush = m_current; 00578 } 00579 } 00580 00581 NAMESPACE_END

Generated on Wed Jul 28 08:07:09 2004 for Crypto++ by doxygen 1.3.7