1// pwdbased.h - written and placed in the public domain by Wei Dai
2
3#ifndef CRYPTOPP_PWDBASED_H
4#define CRYPTOPP_PWDBASED_H
5
#include "cryptlib.h"
#include "hmac.h"
#include "hrtimer.h"
#include "integer.h"
9
10NAMESPACE_BEGIN(CryptoPP)
11
12//! abstract base class for password based key derivation function
13class PasswordBasedKeyDerivationFunction
14{
15public:
16	virtual size_t MaxDerivedKeyLength() const =0;
17	virtual bool UsesPurposeByte() const =0;
18	//! derive key from password
19	/*! If timeInSeconds != 0, will iterate until time elapsed, as measured by ThreadUserTimer
20		Returns actual iteration count, which is equal to iterations if timeInSeconds == 0, and not less than iterations otherwise. */
21	virtual unsigned int DeriveKey(byte *derived, size_t derivedLen, byte purpose, const byte *password, size_t passwordLen, const byte *salt, size_t saltLen, unsigned int iterations, double timeInSeconds=0) const =0;
22};
23
24//! PBKDF1 from PKCS #5, T should be a HashTransformation class
25template <class T>
26class PKCS5_PBKDF1 : public PasswordBasedKeyDerivationFunction
27{
28public:
29	size_t MaxDerivedKeyLength() const {return T::DIGESTSIZE;}
30	bool UsesPurposeByte() const {return false;}
31	// PKCS #5 says PBKDF1 should only take 8-byte salts. This implementation allows salts of any length.
32	unsigned int DeriveKey(byte *derived, size_t derivedLen, byte purpose, const byte *password, size_t passwordLen, const byte *salt, size_t saltLen, unsigned int iterations, double timeInSeconds=0) const;
33};
34
35//! PBKDF2 from PKCS #5, T should be a HashTransformation class
36template <class T>
37class PKCS5_PBKDF2_HMAC : public PasswordBasedKeyDerivationFunction
38{
39public:
40	size_t MaxDerivedKeyLength() const {return 0xffffffffU;}	// should multiply by T::DIGESTSIZE, but gets overflow that way
41	bool UsesPurposeByte() const {return false;}
42	unsigned int DeriveKey(byte *derived, size_t derivedLen, byte purpose, const byte *password, size_t passwordLen, const byte *salt, size_t saltLen, unsigned int iterations, double timeInSeconds=0) const;
43};
44
45/*
46class PBKDF2Params
47{
48public:
49	SecByteBlock m_salt;
50	unsigned int m_interationCount;
51	ASNOptional<ASNUnsignedWrapper<word32> > m_keyLength;
52};
53*/
54
55template <class T>
56unsigned int PKCS5_PBKDF1<T>::DeriveKey(byte *derived, size_t derivedLen, byte purpose, const byte *password, size_t passwordLen, const byte *salt, size_t saltLen, unsigned int iterations, double timeInSeconds) const
57{
58	assert(derivedLen <= MaxDerivedKeyLength());
59	assert(iterations > 0 || timeInSeconds > 0);
60
61	if (!iterations)
62		iterations = 1;
63
64	T hash;
65	hash.Update(password, passwordLen);
66	hash.Update(salt, saltLen);
67
68	SecByteBlock buffer(hash.DigestSize());
69	hash.Final(buffer);
70
71	unsigned int i;
72	ThreadUserTimer timer;
73
74	if (timeInSeconds)
75		timer.StartTimer();
76
77	for (i=1; i<iterations || (timeInSeconds && (i%128!=0 || timer.ElapsedTimeAsDouble() < timeInSeconds)); i++)
78		hash.CalculateDigest(buffer, buffer, buffer.size());
79
80	memcpy(derived, buffer, derivedLen);
81	return i;
82}
83
//! PBKDF2 with HMAC<T> as the PRF. The output is built one digest-sized block at
//! a time: block i is U_1 xor U_2 xor ... xor U_c, where U_1 = HMAC(password,
//! salt || INT(i)) and U_k = HMAC(password, U_(k-1)). The purpose byte is unused.
/*! Returns the iteration count used for every block. */
template <class T>
unsigned int PKCS5_PBKDF2_HMAC<T>::DeriveKey(byte *derived, size_t derivedLen, byte purpose, const byte *password, size_t passwordLen, const byte *salt, size_t saltLen, unsigned int iterations, double timeInSeconds) const
{
	assert(derivedLen <= MaxDerivedKeyLength());
	assert(iterations > 0 || timeInSeconds > 0);

	if (!iterations)
		iterations = 1;

	// HMAC keyed once with the password and reused for every U_k
	HMAC<T> hmac(password, passwordLen);
	SecByteBlock buffer(hmac.DigestSize());
	ThreadUserTimer timer;

	// i is the 1-based block index
	unsigned int i=1;
	while (derivedLen > 0)
	{
		// U_1 = HMAC(salt || INT(i)); INT(i) is i as 4 bytes, most significant first
		hmac.Update(salt, saltLen);
		unsigned int j;
		for (j=0; j<4; j++)
		{
			byte b = byte(i >> ((3-j)*8));
			hmac.Update(&b, 1);
		}
		hmac.Final(buffer);

		// the last block may be truncated; only its first segmentLen bytes are kept
		size_t segmentLen = STDMIN(derivedLen, buffer.size());
		memcpy(derived, buffer, segmentLen);

		// Timed mode (first block only): split the time budget evenly over the
		// number of blocks still to produce, which here is the total count.
		if (timeInSeconds)
		{
			timeInSeconds = timeInSeconds / ((derivedLen + buffer.size() - 1) / buffer.size());
			timer.StartTimer();
		}

		// U_(j+1) = HMAC(U_j), xor-accumulated into the output; in timed mode the
		// timer is consulted only once every 128 rounds
		for (j=1; j<iterations || (timeInSeconds && (j%128!=0 || timer.ElapsedTimeAsDouble() < timeInSeconds)); j++)
		{
			hmac.CalculateDigest(buffer, buffer, buffer.size());
			xorbuf(derived, buffer, segmentLen);
		}

		// Freeze the count measured on the first block so every later block uses
		// the same iteration count, then leave timed mode.
		if (timeInSeconds)
		{
			iterations = j;
			timeInSeconds = 0;
		}

		derived += segmentLen;
		derivedLen -= segmentLen;
		i++;
	}

	return iterations;
}
137
138//! PBKDF from PKCS #12, appendix B, T should be a HashTransformation class
139template <class T>
140class PKCS12_PBKDF : public PasswordBasedKeyDerivationFunction
141{
142public:
143	size_t MaxDerivedKeyLength() const {return size_t(0)-1;}
144	bool UsesPurposeByte() const {return true;}
145	unsigned int DeriveKey(byte *derived, size_t derivedLen, byte purpose, const byte *password, size_t passwordLen, const byte *salt, size_t saltLen, unsigned int iterations, double timeInSeconds) const;
146};
147
//! PKCS #12 appendix B key derivation. Hashes D || I (I = S || P) `iterations`
//! times to get each output block A_i, then updates I in place from A_i before
//! producing the next block.
/*! Returns the iteration count used for every block. */
template <class T>
unsigned int PKCS12_PBKDF<T>::DeriveKey(byte *derived, size_t derivedLen, byte purpose, const byte *password, size_t passwordLen, const byte *salt, size_t saltLen, unsigned int iterations, double timeInSeconds) const
{
	assert(derivedLen <= MaxDerivedKeyLength());
	assert(iterations > 0 || timeInSeconds > 0);

	if (!iterations)
		iterations = 1;

	// Layout of buffer: D (v bytes of purpose) | S (salt repeated to a multiple
	// of v) | P (password repeated to a multiple of v). I aliases S..P, so
	// updating I changes what the next hash of buffer sees.
	const size_t v = T::BLOCKSIZE;	// v is in bytes rather than bits as in PKCS #12
	const size_t DLen = v, SLen = RoundUpToMultipleOf(saltLen, v);
	const size_t PLen = RoundUpToMultipleOf(passwordLen, v), ILen = SLen + PLen;
	SecByteBlock buffer(DLen + SLen + PLen);
	byte *D = buffer, *S = buffer+DLen, *P = buffer+DLen+SLen, *I = S;

	// Fill D with the purpose byte and replicate salt/password cyclically.
	// NOTE(review): an empty salt/password relies on RoundUpToMultipleOf(0, v)
	// returning 0 so the modulo-by-zero loops never run - confirm.
	memset(D, purpose, DLen);
	size_t i;
	for (i=0; i<SLen; i++)
		S[i] = salt[i % saltLen];
	for (i=0; i<PLen; i++)
		P[i] = password[i % passwordLen];


	T hash;
	SecByteBlock Ai(T::DIGESTSIZE), B(v);
	ThreadUserTimer timer;

	while (derivedLen > 0)
	{
		// A_i = H(D || I), then hashed again below for the remaining iterations
		hash.CalculateDigest(Ai, buffer, buffer.size());

		// Timed mode (first block only): split the time budget evenly over the
		// number of output blocks.
		if (timeInSeconds)
		{
			timeInSeconds = timeInSeconds / ((derivedLen + Ai.size() - 1) / Ai.size());
			timer.StartTimer();
		}

		// in timed mode the timer is consulted only once every 128 rounds
		for (i=1; i<iterations || (timeInSeconds && (i%128!=0 || timer.ElapsedTimeAsDouble() < timeInSeconds)); i++)
			hash.CalculateDigest(Ai, Ai, Ai.size());

		// Freeze the count measured on the first block for all later blocks.
		if (timeInSeconds)
		{
			iterations = (unsigned int)i;
			timeInSeconds = 0;
		}

		// B = A_i repeated cyclically to v bytes
		for (i=0; i<B.size(); i++)
			B[i] = Ai[i % Ai.size()];

		// Each v-byte chunk I_j of I becomes I_j + B + 1, re-encoded into v bytes.
		// NOTE(review): assumes Integer::Encode(ptr, v) keeps only the low v bytes
		// (i.e. reduction mod 2^(8v)) - confirm against integer.h.
		// This update also runs after the final block, where it is not needed.
		Integer B1(B, B.size());
		++B1;
		for (i=0; i<ILen; i+=v)
			(Integer(I+i, v) + B1).Encode(I+i, v);

		size_t segmentLen = STDMIN(derivedLen, Ai.size());
		memcpy(derived, Ai, segmentLen);
		derived += segmentLen;
		derivedLen -= segmentLen;
	}

	return iterations;
}
210
211NAMESPACE_END
212
213#endif
214