1// pssr.cpp - written and placed in the public domain by Wei Dai 2 3#include "pch.h" 4#include "pssr.h" 5#include <functional> 6 7NAMESPACE_BEGIN(CryptoPP) 8 9// more in dll.cpp 10template<> const byte EMSA2HashId<RIPEMD160>::id = 0x31; 11template<> const byte EMSA2HashId<RIPEMD128>::id = 0x32; 12template<> const byte EMSA2HashId<Whirlpool>::id = 0x37; 13 14#ifndef CRYPTOPP_IMPORTS 15 16size_t PSSR_MEM_Base::MinRepresentativeBitLength(size_t hashIdentifierLength, size_t digestLength) const 17{ 18 size_t saltLen = SaltLen(digestLength); 19 size_t minPadLen = MinPadLen(digestLength); 20 return 9 + 8*(minPadLen + saltLen + digestLength + hashIdentifierLength); 21} 22 23size_t PSSR_MEM_Base::MaxRecoverableLength(size_t representativeBitLength, size_t hashIdentifierLength, size_t digestLength) const 24{ 25 if (AllowRecovery()) 26 return SaturatingSubtract(representativeBitLength, MinRepresentativeBitLength(hashIdentifierLength, digestLength)) / 8; 27 return 0; 28} 29 30bool PSSR_MEM_Base::IsProbabilistic() const 31{ 32 return SaltLen(1) > 0; 33} 34 35bool PSSR_MEM_Base::AllowNonrecoverablePart() const 36{ 37 return true; 38} 39 40bool PSSR_MEM_Base::RecoverablePartFirst() const 41{ 42 return false; 43} 44 45void PSSR_MEM_Base::ComputeMessageRepresentative(RandomNumberGenerator &rng, 46 const byte *recoverableMessage, size_t recoverableMessageLength, 47 HashTransformation &hash, HashIdentifier hashIdentifier, bool messageEmpty, 48 byte *representative, size_t representativeBitLength) const 49{ 50 assert(representativeBitLength >= MinRepresentativeBitLength(hashIdentifier.second, hash.DigestSize())); 51 52 const size_t u = hashIdentifier.second + 1; 53 const size_t representativeByteLength = BitsToBytes(representativeBitLength); 54 const size_t digestSize = hash.DigestSize(); 55 const size_t saltSize = SaltLen(digestSize); 56 byte *const h = representative + representativeByteLength - u - digestSize; 57 58 SecByteBlock digest(digestSize), salt(saltSize); 59 hash.Final(digest); 60 rng.GenerateBlock(salt, saltSize); 61 62 // compute H = hash of M' 63 byte c[8]; 64 PutWord(false, BIG_ENDIAN_ORDER, c, (word32)SafeRightShift<29>(recoverableMessageLength)); 65 PutWord(false, BIG_ENDIAN_ORDER, c+4, word32(recoverableMessageLength << 3)); 66 hash.Update(c, 8); 67 hash.Update(recoverableMessage, recoverableMessageLength); 68 hash.Update(digest, digestSize); 69 hash.Update(salt, saltSize); 70 hash.Final(h); 71 72 // compute representative 73 GetMGF().GenerateAndMask(hash, representative, representativeByteLength - u - digestSize, h, digestSize, false); 74 byte *xorStart = representative + representativeByteLength - u - digestSize - salt.size() - recoverableMessageLength - 1; 75 xorStart[0] ^= 1; 76 xorbuf(xorStart + 1, recoverableMessage, recoverableMessageLength); 77 xorbuf(xorStart + 1 + recoverableMessageLength, salt, salt.size()); 78 memcpy(representative + representativeByteLength - u, hashIdentifier.first, hashIdentifier.second); 79 representative[representativeByteLength - 1] = hashIdentifier.second ? 0xcc : 0xbc; 80 if (representativeBitLength % 8 != 0) 81 representative[0] = (byte)Crop(representative[0], representativeBitLength % 8); 82} 83 84DecodingResult PSSR_MEM_Base::RecoverMessageFromRepresentative( 85 HashTransformation &hash, HashIdentifier hashIdentifier, bool messageEmpty, 86 byte *representative, size_t representativeBitLength, 87 byte *recoverableMessage) const 88{ 89 assert(representativeBitLength >= MinRepresentativeBitLength(hashIdentifier.second, hash.DigestSize())); 90 91 const size_t u = hashIdentifier.second + 1; 92 const size_t representativeByteLength = BitsToBytes(representativeBitLength); 93 const size_t digestSize = hash.DigestSize(); 94 const size_t saltSize = SaltLen(digestSize); 95 const byte *const h = representative + representativeByteLength - u - digestSize; 96 97 SecByteBlock digest(digestSize); 98 hash.Final(digest); 99 100 DecodingResult result(0); 101 bool &valid = result.isValidCoding; 102 size_t &recoverableMessageLength = result.messageLength; 103 104 valid = (representative[representativeByteLength - 1] == (hashIdentifier.second ? 0xcc : 0xbc)) && valid; 105 valid = VerifyBufsEqual(representative + representativeByteLength - u, hashIdentifier.first, hashIdentifier.second) && valid; 106 107 GetMGF().GenerateAndMask(hash, representative, representativeByteLength - u - digestSize, h, digestSize); 108 if (representativeBitLength % 8 != 0) 109 representative[0] = (byte)Crop(representative[0], representativeBitLength % 8); 110 111 // extract salt and recoverableMessage from DB = 00 ... || 01 || M || salt 112 byte *salt = representative + representativeByteLength - u - digestSize - saltSize; 113 byte *M = std::find_if(representative, salt-1, std::bind2nd(std::not_equal_to<byte>(), 0)); 114 recoverableMessageLength = salt-M-1; 115 if (*M == 0x01 116 && (size_t)(M - representative - (representativeBitLength % 8 != 0)) >= MinPadLen(digestSize) 117 && recoverableMessageLength <= MaxRecoverableLength(representativeBitLength, hashIdentifier.second, digestSize)) 118 { 119 memcpy(recoverableMessage, M+1, recoverableMessageLength); 120 } 121 else 122 { 123 recoverableMessageLength = 0; 124 valid = false; 125 } 126 127 // verify H = hash of M' 128 byte c[8]; 129 PutWord(false, BIG_ENDIAN_ORDER, c, (word32)SafeRightShift<29>(recoverableMessageLength)); 130 PutWord(false, BIG_ENDIAN_ORDER, c+4, word32(recoverableMessageLength << 3)); 131 hash.Update(c, 8); 132 hash.Update(recoverableMessage, recoverableMessageLength); 133 hash.Update(digest, digestSize); 134 hash.Update(salt, saltSize); 135 valid = hash.Verify(h) && valid; 136 137 if (!AllowRecovery() && valid && recoverableMessageLength != 0) 138 {throw NotImplemented("PSSR_MEM: message recovery disabled");} 139 140 return result; 141} 142 143#endif 144 145NAMESPACE_END 146