1// strciphr.cpp - written and placed in the public domain by Wei Dai
2
3#include "pch.h"
4
5#ifndef CRYPTOPP_IMPORTS
6
7#include "strciphr.h"
8
9NAMESPACE_BEGIN(CryptoPP)
10
11template <class S>
12void AdditiveCipherTemplate<S>::UncheckedSetKey(const byte *key, unsigned int length, const NameValuePairs &params)
13{
14	PolicyInterface &policy = this->AccessPolicy();
15	policy.CipherSetKey(params, key, length);
16	m_leftOver = 0;
17	unsigned int bufferByteSize = policy.CanOperateKeystream() ? GetBufferByteSize(policy) : RoundUpToMultipleOf(1024U, GetBufferByteSize(policy));
18	m_buffer.New(bufferByteSize);
19
20	if (this->IsResynchronizable())
21	{
22		size_t ivLength;
23		const byte *iv = this->GetIVAndThrowIfInvalid(params, ivLength);
24		policy.CipherResynchronize(m_buffer, iv, ivLength);
25	}
26}
27
28template <class S>
29void AdditiveCipherTemplate<S>::GenerateBlock(byte *outString, size_t length)
30{
31	if (m_leftOver > 0)
32	{
33		size_t len = STDMIN(m_leftOver, length);
34		memcpy(outString, KeystreamBufferEnd()-m_leftOver, len);
35		length -= len;
36		m_leftOver -= len;
37		outString += len;
38
39		if (!length)
40			return;
41	}
42	assert(m_leftOver == 0);
43
44	PolicyInterface &policy = this->AccessPolicy();
45	unsigned int bytesPerIteration = policy.GetBytesPerIteration();
46
47	if (length >= bytesPerIteration)
48	{
49		size_t iterations = length / bytesPerIteration;
50		policy.WriteKeystream(outString, iterations);
51		outString += iterations * bytesPerIteration;
52		length -= iterations * bytesPerIteration;
53	}
54
55	if (length > 0)
56	{
57		size_t bufferByteSize = RoundUpToMultipleOf(length, bytesPerIteration);
58		size_t bufferIterations = bufferByteSize / bytesPerIteration;
59
60		policy.WriteKeystream(KeystreamBufferEnd()-bufferByteSize, bufferIterations);
61		memcpy(outString, KeystreamBufferEnd()-bufferByteSize, length);
62		m_leftOver = bufferByteSize - length;
63	}
64}
65
66template <class S>
67void AdditiveCipherTemplate<S>::ProcessData(byte *outString, const byte *inString, size_t length)
68{
69	if (m_leftOver > 0)
70	{
71		size_t len = STDMIN(m_leftOver, length);
72		xorbuf(outString, inString, KeystreamBufferEnd()-m_leftOver, len);
73		length -= len;
74		m_leftOver -= len;
75		inString += len;
76		outString += len;
77
78		if (!length)
79			return;
80	}
81	assert(m_leftOver == 0);
82
83	PolicyInterface &policy = this->AccessPolicy();
84	unsigned int bytesPerIteration = policy.GetBytesPerIteration();
85
86	if (policy.CanOperateKeystream() && length >= bytesPerIteration)
87	{
88		size_t iterations = length / bytesPerIteration;
89		unsigned int alignment = policy.GetAlignment();
90		KeystreamOperation operation = KeystreamOperation((IsAlignedOn(inString, alignment) * 2) | (int)IsAlignedOn(outString, alignment));
91
92		policy.OperateKeystream(operation, outString, inString, iterations);
93
94		inString += iterations * bytesPerIteration;
95		outString += iterations * bytesPerIteration;
96		length -= iterations * bytesPerIteration;
97
98		if (!length)
99			return;
100	}
101
102	size_t bufferByteSize = m_buffer.size();
103	size_t bufferIterations = bufferByteSize / bytesPerIteration;
104
105	while (length >= bufferByteSize)
106	{
107		policy.WriteKeystream(m_buffer, bufferIterations);
108		xorbuf(outString, inString, KeystreamBufferBegin(), bufferByteSize);
109		length -= bufferByteSize;
110		inString += bufferByteSize;
111		outString += bufferByteSize;
112	}
113
114	if (length > 0)
115	{
116		bufferByteSize = RoundUpToMultipleOf(length, bytesPerIteration);
117		bufferIterations = bufferByteSize / bytesPerIteration;
118
119		policy.WriteKeystream(KeystreamBufferEnd()-bufferByteSize, bufferIterations);
120		xorbuf(outString, inString, KeystreamBufferEnd()-bufferByteSize, length);
121		m_leftOver = bufferByteSize - length;
122	}
123}
124
125template <class S>
126void AdditiveCipherTemplate<S>::Resynchronize(const byte *iv, int length)
127{
128	PolicyInterface &policy = this->AccessPolicy();
129	m_leftOver = 0;
130	m_buffer.New(GetBufferByteSize(policy));
131	policy.CipherResynchronize(m_buffer, iv, this->ThrowIfInvalidIVLength(length));
132}
133
134template <class BASE>
135void AdditiveCipherTemplate<BASE>::Seek(lword position)
136{
137	PolicyInterface &policy = this->AccessPolicy();
138	unsigned int bytesPerIteration = policy.GetBytesPerIteration();
139
140	policy.SeekToIteration(position / bytesPerIteration);
141	position %= bytesPerIteration;
142
143	if (position > 0)
144	{
145		policy.WriteKeystream(KeystreamBufferEnd()-bytesPerIteration, 1);
146		m_leftOver = bytesPerIteration - (unsigned int)position;
147	}
148	else
149		m_leftOver = 0;
150}
151
152template <class BASE>
153void CFB_CipherTemplate<BASE>::UncheckedSetKey(const byte *key, unsigned int length, const NameValuePairs &params)
154{
155	PolicyInterface &policy = this->AccessPolicy();
156	policy.CipherSetKey(params, key, length);
157
158	if (this->IsResynchronizable())
159	{
160		size_t ivLength;
161		const byte *iv = this->GetIVAndThrowIfInvalid(params, ivLength);
162		policy.CipherResynchronize(iv, ivLength);
163	}
164
165	m_leftOver = policy.GetBytesPerIteration();
166}
167
168template <class BASE>
169void CFB_CipherTemplate<BASE>::Resynchronize(const byte *iv, int length)
170{
171	PolicyInterface &policy = this->AccessPolicy();
172	policy.CipherResynchronize(iv, this->ThrowIfInvalidIVLength(length));
173	m_leftOver = policy.GetBytesPerIteration();
174}
175
176template <class BASE>
177void CFB_CipherTemplate<BASE>::ProcessData(byte *outString, const byte *inString, size_t length)
178{
179	assert(length % this->MandatoryBlockSize() == 0);
180
181	PolicyInterface &policy = this->AccessPolicy();
182	unsigned int bytesPerIteration = policy.GetBytesPerIteration();
183	unsigned int alignment = policy.GetAlignment();
184	byte *reg = policy.GetRegisterBegin();
185
186	if (m_leftOver)
187	{
188		size_t len = STDMIN(m_leftOver, length);
189		CombineMessageAndShiftRegister(outString, reg + bytesPerIteration - m_leftOver, inString, len);
190		m_leftOver -= len;
191		length -= len;
192		inString += len;
193		outString += len;
194	}
195
196	if (!length)
197		return;
198
199	assert(m_leftOver == 0);
200
201	if (policy.CanIterate() && length >= bytesPerIteration && IsAlignedOn(outString, alignment))
202	{
203		if (IsAlignedOn(inString, alignment))
204			policy.Iterate(outString, inString, GetCipherDir(*this), length / bytesPerIteration);
205		else
206		{
207			memcpy(outString, inString, length);
208			policy.Iterate(outString, outString, GetCipherDir(*this), length / bytesPerIteration);
209		}
210		inString += length - length % bytesPerIteration;
211		outString += length - length % bytesPerIteration;
212		length %= bytesPerIteration;
213	}
214
215	while (length >= bytesPerIteration)
216	{
217		policy.TransformRegister();
218		CombineMessageAndShiftRegister(outString, reg, inString, bytesPerIteration);
219		length -= bytesPerIteration;
220		inString += bytesPerIteration;
221		outString += bytesPerIteration;
222	}
223
224	if (length > 0)
225	{
226		policy.TransformRegister();
227		CombineMessageAndShiftRegister(outString, reg, inString, length);
228		m_leftOver = bytesPerIteration - length;
229	}
230}
231
232template <class BASE>
233void CFB_EncryptionTemplate<BASE>::CombineMessageAndShiftRegister(byte *output, byte *reg, const byte *message, size_t length)
234{
235	xorbuf(reg, message, length);
236	memcpy(output, reg, length);
237}
238
239template <class BASE>
240void CFB_DecryptionTemplate<BASE>::CombineMessageAndShiftRegister(byte *output, byte *reg, const byte *message, size_t length)
241{
242	for (unsigned int i=0; i<length; i++)
243	{
244		byte b = message[i];
245		output[i] = reg[i] ^ b;
246		reg[i] = b;
247	}
248}
249
250NAMESPACE_END
251
252#endif
253