1// algebra.cpp - written and placed in the public domain by Wei Dai
2
3#include "pch.h"
4
5#ifndef CRYPTOPP_ALGEBRA_CPP	// SunCC workaround: compiler could cause this file to be included twice
6#define CRYPTOPP_ALGEBRA_CPP
7
8#include "algebra.h"
9#include "integer.h"
10
11#include <vector>
12
13NAMESPACE_BEGIN(CryptoPP)
14
15template <class T> const T& AbstractGroup<T>::Double(const Element &a) const
16{
17	return Add(a, a);
18}
19
20template <class T> const T& AbstractGroup<T>::Subtract(const Element &a, const Element &b) const
21{
22	// make copy of a in case Inverse() overwrites it
23	Element a1(a);
24	return Add(a1, Inverse(b));
25}
26
27template <class T> T& AbstractGroup<T>::Accumulate(Element &a, const Element &b) const
28{
29	return a = Add(a, b);
30}
31
32template <class T> T& AbstractGroup<T>::Reduce(Element &a, const Element &b) const
33{
34	return a = Subtract(a, b);
35}
36
37template <class T> const T& AbstractRing<T>::Square(const Element &a) const
38{
39	return Multiply(a, a);
40}
41
42template <class T> const T& AbstractRing<T>::Divide(const Element &a, const Element &b) const
43{
44	// make copy of a in case MultiplicativeInverse() overwrites it
45	Element a1(a);
46	return Multiply(a1, MultiplicativeInverse(b));
47}
48
49template <class T> const T& AbstractEuclideanDomain<T>::Mod(const Element &a, const Element &b) const
50{
51	Element q;
52	DivisionAlgorithm(result, q, a, b);
53	return result;
54}
55
56template <class T> const T& AbstractEuclideanDomain<T>::Gcd(const Element &a, const Element &b) const
57{
58	Element g[3]={b, a};
59	unsigned int i0=0, i1=1, i2=2;
60
61	while (!Equal(g[i1], this->Identity()))
62	{
63		g[i2] = Mod(g[i0], g[i1]);
64		unsigned int t = i0; i0 = i1; i1 = i2; i2 = t;
65	}
66
67	return result = g[i0];
68}
69
70template <class T> const typename QuotientRing<T>::Element& QuotientRing<T>::MultiplicativeInverse(const Element &a) const
71{
72	Element g[3]={m_modulus, a};
73	Element v[3]={m_domain.Identity(), m_domain.MultiplicativeIdentity()};
74	Element y;
75	unsigned int i0=0, i1=1, i2=2;
76
77	while (!Equal(g[i1], Identity()))
78	{
79		// y = g[i0] / g[i1];
80		// g[i2] = g[i0] % g[i1];
81		m_domain.DivisionAlgorithm(g[i2], y, g[i0], g[i1]);
82		// v[i2] = v[i0] - (v[i1] * y);
83		v[i2] = m_domain.Subtract(v[i0], m_domain.Multiply(v[i1], y));
84		unsigned int t = i0; i0 = i1; i1 = i2; i2 = t;
85	}
86
87	return m_domain.IsUnit(g[i0]) ? m_domain.Divide(v[i0], g[i0]) : m_domain.Identity();
88}
89
90template <class T> T AbstractGroup<T>::ScalarMultiply(const Element &base, const Integer &exponent) const
91{
92	Element result;
93	SimultaneousMultiply(&result, base, &exponent, 1);
94	return result;
95}
96
97template <class T> T AbstractGroup<T>::CascadeScalarMultiply(const Element &x, const Integer &e1, const Element &y, const Integer &e2) const
98{
99	const unsigned expLen = STDMAX(e1.BitCount(), e2.BitCount());
100	if (expLen==0)
101		return Identity();
102
103	const unsigned w = (expLen <= 46 ? 1 : (expLen <= 260 ? 2 : 3));
104	const unsigned tableSize = 1<<w;
105	std::vector<Element> powerTable(tableSize << w);
106
107	powerTable[1] = x;
108	powerTable[tableSize] = y;
109	if (w==1)
110		powerTable[3] = Add(x,y);
111	else
112	{
113		powerTable[2] = Double(x);
114		powerTable[2*tableSize] = Double(y);
115
116		unsigned i, j;
117
118		for (i=3; i<tableSize; i+=2)
119			powerTable[i] = Add(powerTable[i-2], powerTable[2]);
120		for (i=1; i<tableSize; i+=2)
121			for (j=i+tableSize; j<(tableSize<<w); j+=tableSize)
122				powerTable[j] = Add(powerTable[j-tableSize], y);
123
124		for (i=3*tableSize; i<(tableSize<<w); i+=2*tableSize)
125			powerTable[i] = Add(powerTable[i-2*tableSize], powerTable[2*tableSize]);
126		for (i=tableSize; i<(tableSize<<w); i+=2*tableSize)
127			for (j=i+2; j<i+tableSize; j+=2)
128				powerTable[j] = Add(powerTable[j-1], x);
129	}
130
131	Element result;
132	unsigned power1 = 0, power2 = 0, prevPosition = expLen-1;
133	bool firstTime = true;
134
135	for (int i = expLen-1; i>=0; i--)
136	{
137		power1 = 2*power1 + e1.GetBit(i);
138		power2 = 2*power2 + e2.GetBit(i);
139
140		if (i==0 || 2*power1 >= tableSize || 2*power2 >= tableSize)
141		{
142			unsigned squaresBefore = prevPosition-i;
143			unsigned squaresAfter = 0;
144			prevPosition = i;
145			while ((power1 || power2) && power1%2 == 0 && power2%2==0)
146			{
147				power1 /= 2;
148				power2 /= 2;
149				squaresBefore--;
150				squaresAfter++;
151			}
152			if (firstTime)
153			{
154				result = powerTable[(power2<<w) + power1];
155				firstTime = false;
156			}
157			else
158			{
159				while (squaresBefore--)
160					result = Double(result);
161				if (power1 || power2)
162					Accumulate(result, powerTable[(power2<<w) + power1]);
163			}
164			while (squaresAfter--)
165				result = Double(result);
166			power1 = power2 = 0;
167		}
168	}
169	return result;
170}
171
172template <class Element, class Iterator> Element GeneralCascadeMultiplication(const AbstractGroup<Element> &group, Iterator begin, Iterator end)
173{
174	if (end-begin == 1)
175		return group.ScalarMultiply(begin->base, begin->exponent);
176	else if (end-begin == 2)
177		return group.CascadeScalarMultiply(begin->base, begin->exponent, (begin+1)->base, (begin+1)->exponent);
178	else
179	{
180		Integer q, t;
181		Iterator last = end;
182		--last;
183
184		std::make_heap(begin, end);
185		std::pop_heap(begin, end);
186
187		while (!!begin->exponent)
188		{
189			// last->exponent is largest exponent, begin->exponent is next largest
190			t = last->exponent;
191			Integer::Divide(last->exponent, q, t, begin->exponent);
192
193			if (q == Integer::One())
194				group.Accumulate(begin->base, last->base);	// avoid overhead of ScalarMultiply()
195			else
196				group.Accumulate(begin->base, group.ScalarMultiply(last->base, q));
197
198			std::push_heap(begin, end);
199			std::pop_heap(begin, end);
200		}
201
202		return group.ScalarMultiply(last->base, last->exponent);
203	}
204}
205
206struct WindowSlider
207{
208	WindowSlider(const Integer &expIn, bool fastNegate, unsigned int windowSizeIn=0)
209		: exp(expIn), windowModulus(Integer::One()), windowSize(windowSizeIn), windowBegin(0), fastNegate(fastNegate), firstTime(true), finished(false)
210	{
211		if (windowSize == 0)
212		{
213			unsigned int expLen = exp.BitCount();
214			windowSize = expLen <= 17 ? 1 : (expLen <= 24 ? 2 : (expLen <= 70 ? 3 : (expLen <= 197 ? 4 : (expLen <= 539 ? 5 : (expLen <= 1434 ? 6 : 7)))));
215		}
216		windowModulus <<= windowSize;
217	}
218
219	void FindNextWindow()
220	{
221		unsigned int expLen = exp.WordCount() * WORD_BITS;
222		unsigned int skipCount = firstTime ? 0 : windowSize;
223		firstTime = false;
224		while (!exp.GetBit(skipCount))
225		{
226			if (skipCount >= expLen)
227			{
228				finished = true;
229				return;
230			}
231			skipCount++;
232		}
233
234		exp >>= skipCount;
235		windowBegin += skipCount;
236		expWindow = word32(exp % (word(1) << windowSize));
237
238		if (fastNegate && exp.GetBit(windowSize))
239		{
240			negateNext = true;
241			expWindow = (word32(1) << windowSize) - expWindow;
242			exp += windowModulus;
243		}
244		else
245			negateNext = false;
246	}
247
248	Integer exp, windowModulus;
249	unsigned int windowSize, windowBegin;
250	word32 expWindow;
251	bool fastNegate, negateNext, firstTime, finished;
252};
253
254template <class T>
255void AbstractGroup<T>::SimultaneousMultiply(T *results, const T &base, const Integer *expBegin, unsigned int expCount) const
256{
257	std::vector<std::vector<Element> > buckets(expCount);
258	std::vector<WindowSlider> exponents;
259	exponents.reserve(expCount);
260	unsigned int i;
261
262	for (i=0; i<expCount; i++)
263	{
264		assert(expBegin->NotNegative());
265		exponents.push_back(WindowSlider(*expBegin++, InversionIsFast(), 0));
266		exponents[i].FindNextWindow();
267		buckets[i].resize(1<<(exponents[i].windowSize-1), Identity());
268	}
269
270	unsigned int expBitPosition = 0;
271	Element g = base;
272	bool notDone = true;
273
274	while (notDone)
275	{
276		notDone = false;
277		for (i=0; i<expCount; i++)
278		{
279			if (!exponents[i].finished && expBitPosition == exponents[i].windowBegin)
280			{
281				Element &bucket = buckets[i][exponents[i].expWindow/2];
282				if (exponents[i].negateNext)
283					Accumulate(bucket, Inverse(g));
284				else
285					Accumulate(bucket, g);
286				exponents[i].FindNextWindow();
287			}
288			notDone = notDone || !exponents[i].finished;
289		}
290
291		if (notDone)
292		{
293			g = Double(g);
294			expBitPosition++;
295		}
296	}
297
298	for (i=0; i<expCount; i++)
299	{
300		Element &r = *results++;
301		r = buckets[i][buckets[i].size()-1];
302		if (buckets[i].size() > 1)
303		{
304			for (int j = (int)buckets[i].size()-2; j >= 1; j--)
305			{
306				Accumulate(buckets[i][j], buckets[i][j+1]);
307				Accumulate(r, buckets[i][j]);
308			}
309			Accumulate(buckets[i][0], buckets[i][1]);
310			r = Add(Double(r), buckets[i][0]);
311		}
312	}
313}
314
315template <class T> T AbstractRing<T>::Exponentiate(const Element &base, const Integer &exponent) const
316{
317	Element result;
318	SimultaneousExponentiate(&result, base, &exponent, 1);
319	return result;
320}
321
322template <class T> T AbstractRing<T>::CascadeExponentiate(const Element &x, const Integer &e1, const Element &y, const Integer &e2) const
323{
324	return MultiplicativeGroup().AbstractGroup<T>::CascadeScalarMultiply(x, e1, y, e2);
325}
326
327template <class Element, class Iterator> Element GeneralCascadeExponentiation(const AbstractRing<Element> &ring, Iterator begin, Iterator end)
328{
329	return GeneralCascadeMultiplication<Element>(ring.MultiplicativeGroup(), begin, end);
330}
331
332template <class T>
333void AbstractRing<T>::SimultaneousExponentiate(T *results, const T &base, const Integer *exponents, unsigned int expCount) const
334{
335	MultiplicativeGroup().AbstractGroup<T>::SimultaneousMultiply(results, base, exponents, expCount);
336}
337
338NAMESPACE_END
339
340#endif
341