1// pssr.cpp - written and placed in the public domain by Wei Dai
2
3#include "pch.h"
4#include "pssr.h"
5#include <functional>
6
7NAMESPACE_BEGIN(CryptoPP)
8
// Definitions of the EMSA2HashId<H>::id byte constant for the hash functions
// listed here; the specializations for further hashes are in dll.cpp.
// NOTE(review): the values are presumably the standardized hash-function
// identifier bytes that EMSA2 embeds in its encoding - confirm against the
// EMSA2 code and the relevant standard before changing any of them.
template<> const byte EMSA2HashId<RIPEMD160>::id = 0x31;
template<> const byte EMSA2HashId<RIPEMD128>::id = 0x32;
template<> const byte EMSA2HashId<Whirlpool>::id = 0x37;
13
14#ifndef CRYPTOPP_IMPORTS
15
16size_t PSSR_MEM_Base::MinRepresentativeBitLength(size_t hashIdentifierLength, size_t digestLength) const
17{
18	size_t saltLen = SaltLen(digestLength);
19	size_t minPadLen = MinPadLen(digestLength);
20	return 9 + 8*(minPadLen + saltLen + digestLength + hashIdentifierLength);
21}
22
23size_t PSSR_MEM_Base::MaxRecoverableLength(size_t representativeBitLength, size_t hashIdentifierLength, size_t digestLength) const
24{
25	if (AllowRecovery())
26		return SaturatingSubtract(representativeBitLength, MinRepresentativeBitLength(hashIdentifierLength, digestLength)) / 8;
27	return 0;
28}
29
30bool PSSR_MEM_Base::IsProbabilistic() const
31{
32	return SaltLen(1) > 0;
33}
34
// Always returns true for this encoding method.
// NOTE(review): going by the name, this tells the signature framework that a
// message may include a part that is not embedded in the representative -
// confirm against the base-class documentation.
bool PSSR_MEM_Base::AllowNonrecoverablePart() const
{
	return true;
}
39
// Always returns false for this encoding method.
// NOTE(review): going by the name, this tells the signature framework that
// the recoverable part is not required to be supplied before the rest of the
// message - confirm against the base-class documentation.
bool PSSR_MEM_Base::RecoverablePartFirst() const
{
	return false;
}
44
45void PSSR_MEM_Base::ComputeMessageRepresentative(RandomNumberGenerator &rng,
46	const byte *recoverableMessage, size_t recoverableMessageLength,
47	HashTransformation &hash, HashIdentifier hashIdentifier, bool messageEmpty,
48	byte *representative, size_t representativeBitLength) const
49{
50	assert(representativeBitLength >= MinRepresentativeBitLength(hashIdentifier.second, hash.DigestSize()));
51
52	const size_t u = hashIdentifier.second + 1;
53	const size_t representativeByteLength = BitsToBytes(representativeBitLength);
54	const size_t digestSize = hash.DigestSize();
55	const size_t saltSize = SaltLen(digestSize);
56	byte *const h = representative + representativeByteLength - u - digestSize;
57
58	SecByteBlock digest(digestSize), salt(saltSize);
59	hash.Final(digest);
60	rng.GenerateBlock(salt, saltSize);
61
62	// compute H = hash of M'
63	byte c[8];
64	PutWord(false, BIG_ENDIAN_ORDER, c, (word32)SafeRightShift<29>(recoverableMessageLength));
65	PutWord(false, BIG_ENDIAN_ORDER, c+4, word32(recoverableMessageLength << 3));
66	hash.Update(c, 8);
67	hash.Update(recoverableMessage, recoverableMessageLength);
68	hash.Update(digest, digestSize);
69	hash.Update(salt, saltSize);
70	hash.Final(h);
71
72	// compute representative
73	GetMGF().GenerateAndMask(hash, representative, representativeByteLength - u - digestSize, h, digestSize, false);
74	byte *xorStart = representative + representativeByteLength - u - digestSize - salt.size() - recoverableMessageLength - 1;
75	xorStart[0] ^= 1;
76	xorbuf(xorStart + 1, recoverableMessage, recoverableMessageLength);
77	xorbuf(xorStart + 1 + recoverableMessageLength, salt, salt.size());
78	memcpy(representative + representativeByteLength - u, hashIdentifier.first, hashIdentifier.second);
79	representative[representativeByteLength - 1] = hashIdentifier.second ? 0xcc : 0xbc;
80	if (representativeBitLength % 8 != 0)
81		representative[0] = (byte)Crop(representative[0], representativeBitLength % 8);
82}
83
84DecodingResult PSSR_MEM_Base::RecoverMessageFromRepresentative(
85	HashTransformation &hash, HashIdentifier hashIdentifier, bool messageEmpty,
86	byte *representative, size_t representativeBitLength,
87	byte *recoverableMessage) const
88{
89	assert(representativeBitLength >= MinRepresentativeBitLength(hashIdentifier.second, hash.DigestSize()));
90
91	const size_t u = hashIdentifier.second + 1;
92	const size_t representativeByteLength = BitsToBytes(representativeBitLength);
93	const size_t digestSize = hash.DigestSize();
94	const size_t saltSize = SaltLen(digestSize);
95	const byte *const h = representative + representativeByteLength - u - digestSize;
96
97	SecByteBlock digest(digestSize);
98	hash.Final(digest);
99
100	DecodingResult result(0);
101	bool &valid = result.isValidCoding;
102	size_t &recoverableMessageLength = result.messageLength;
103
104	valid = (representative[representativeByteLength - 1] == (hashIdentifier.second ? 0xcc : 0xbc)) && valid;
105	valid = VerifyBufsEqual(representative + representativeByteLength - u, hashIdentifier.first, hashIdentifier.second) && valid;
106
107	GetMGF().GenerateAndMask(hash, representative, representativeByteLength - u - digestSize, h, digestSize);
108	if (representativeBitLength % 8 != 0)
109		representative[0] = (byte)Crop(representative[0], representativeBitLength % 8);
110
111	// extract salt and recoverableMessage from DB = 00 ... || 01 || M || salt
112	byte *salt = representative + representativeByteLength - u - digestSize - saltSize;
113	byte *M = std::find_if(representative, salt-1, std::bind2nd(std::not_equal_to<byte>(), 0));
114	recoverableMessageLength = salt-M-1;
115	if (*M == 0x01
116		&& (size_t)(M - representative - (representativeBitLength % 8 != 0)) >= MinPadLen(digestSize)
117		&& recoverableMessageLength <= MaxRecoverableLength(representativeBitLength, hashIdentifier.second, digestSize))
118	{
119		memcpy(recoverableMessage, M+1, recoverableMessageLength);
120	}
121	else
122	{
123		recoverableMessageLength = 0;
124		valid = false;
125	}
126
127	// verify H = hash of M'
128	byte c[8];
129	PutWord(false, BIG_ENDIAN_ORDER, c, (word32)SafeRightShift<29>(recoverableMessageLength));
130	PutWord(false, BIG_ENDIAN_ORDER, c+4, word32(recoverableMessageLength << 3));
131	hash.Update(c, 8);
132	hash.Update(recoverableMessage, recoverableMessageLength);
133	hash.Update(digest, digestSize);
134	hash.Update(salt, saltSize);
135	valid = hash.Verify(h) && valid;
136
137	if (!AllowRecovery() && valid && recoverableMessageLength != 0)
138		{throw NotImplemented("PSSR_MEM: message recovery disabled");}
139
140	return result;
141}
142
143#endif
144
145NAMESPACE_END
146