1// ecp.cpp - written and placed in the public domain by Wei Dai
2
3#include "pch.h"
4
5#ifndef CRYPTOPP_IMPORTS
6
7#include "ecp.h"
8#include "asn.h"
9#include "nbtheory.h"
10
11#include "algebra.cpp"
12
13NAMESPACE_BEGIN(CryptoPP)
14
15ANONYMOUS_NAMESPACE_BEGIN
16static inline ECP::Point ToMontgomery(const ModularArithmetic &mr, const ECP::Point &P)
17{
18	return P.identity ? P : ECP::Point(mr.ConvertIn(P.x), mr.ConvertIn(P.y));
19}
20
21static inline ECP::Point FromMontgomery(const ModularArithmetic &mr, const ECP::Point &P)
22{
23	return P.identity ? P : ECP::Point(mr.ConvertOut(P.x), mr.ConvertOut(P.y));
24}
25NAMESPACE_END
26
27ECP::ECP(const ECP &ecp, bool convertToMontgomeryRepresentation)
28{
29	if (convertToMontgomeryRepresentation && !ecp.GetField().IsMontgomeryRepresentation())
30	{
31		m_fieldPtr.reset(new MontgomeryRepresentation(ecp.GetField().GetModulus()));
32		m_a = GetField().ConvertIn(ecp.m_a);
33		m_b = GetField().ConvertIn(ecp.m_b);
34	}
35	else
36		operator=(ecp);
37}
38
39ECP::ECP(BufferedTransformation &bt)
40	: m_fieldPtr(new Field(bt))
41{
42	BERSequenceDecoder seq(bt);
43	GetField().BERDecodeElement(seq, m_a);
44	GetField().BERDecodeElement(seq, m_b);
45	// skip optional seed
46	if (!seq.EndReached())
47	{
48		SecByteBlock seed;
49		unsigned int unused;
50		BERDecodeBitString(seq, seed, unused);
51	}
52	seq.MessageEnd();
53}
54
55void ECP::DEREncode(BufferedTransformation &bt) const
56{
57	GetField().DEREncode(bt);
58	DERSequenceEncoder seq(bt);
59	GetField().DEREncodeElement(seq, m_a);
60	GetField().DEREncodeElement(seq, m_b);
61	seq.MessageEnd();
62}
63
64bool ECP::DecodePoint(ECP::Point &P, const byte *encodedPoint, size_t encodedPointLen) const
65{
66	StringStore store(encodedPoint, encodedPointLen);
67	return DecodePoint(P, store, encodedPointLen);
68}
69
70bool ECP::DecodePoint(ECP::Point &P, BufferedTransformation &bt, size_t encodedPointLen) const
71{
72	byte type;
73	if (encodedPointLen < 1 || !bt.Get(type))
74		return false;
75
76	switch (type)
77	{
78	case 0:
79		P.identity = true;
80		return true;
81	case 2:
82	case 3:
83	{
84		if (encodedPointLen != EncodedPointSize(true))
85			return false;
86
87		Integer p = FieldSize();
88
89		P.identity = false;
90		P.x.Decode(bt, GetField().MaxElementByteLength());
91		P.y = ((P.x*P.x+m_a)*P.x+m_b) % p;
92
93		if (Jacobi(P.y, p) !=1)
94			return false;
95
96		P.y = ModularSquareRoot(P.y, p);
97
98		if ((type & 1) != P.y.GetBit(0))
99			P.y = p-P.y;
100
101		return true;
102	}
103	case 4:
104	{
105		if (encodedPointLen != EncodedPointSize(false))
106			return false;
107
108		unsigned int len = GetField().MaxElementByteLength();
109		P.identity = false;
110		P.x.Decode(bt, len);
111		P.y.Decode(bt, len);
112		return true;
113	}
114	default:
115		return false;
116	}
117}
118
119void ECP::EncodePoint(BufferedTransformation &bt, const Point &P, bool compressed) const
120{
121	if (P.identity)
122		NullStore().TransferTo(bt, EncodedPointSize(compressed));
123	else if (compressed)
124	{
125		bt.Put(2 + P.y.GetBit(0));
126		P.x.Encode(bt, GetField().MaxElementByteLength());
127	}
128	else
129	{
130		unsigned int len = GetField().MaxElementByteLength();
131		bt.Put(4);	// uncompressed
132		P.x.Encode(bt, len);
133		P.y.Encode(bt, len);
134	}
135}
136
137void ECP::EncodePoint(byte *encodedPoint, const Point &P, bool compressed) const
138{
139	ArraySink sink(encodedPoint, EncodedPointSize(compressed));
140	EncodePoint(sink, P, compressed);
141	assert(sink.TotalPutLength() == EncodedPointSize(compressed));
142}
143
144ECP::Point ECP::BERDecodePoint(BufferedTransformation &bt) const
145{
146	SecByteBlock str;
147	BERDecodeOctetString(bt, str);
148	Point P;
149	if (!DecodePoint(P, str, str.size()))
150		BERDecodeError();
151	return P;
152}
153
154void ECP::DEREncodePoint(BufferedTransformation &bt, const Point &P, bool compressed) const
155{
156	SecByteBlock str(EncodedPointSize(compressed));
157	EncodePoint(str, P, compressed);
158	DEREncodeOctetString(bt, str);
159}
160
161bool ECP::ValidateParameters(RandomNumberGenerator &rng, unsigned int level) const
162{
163	Integer p = FieldSize();
164
165	bool pass = p.IsOdd();
166	pass = pass && !m_a.IsNegative() && m_a<p && !m_b.IsNegative() && m_b<p;
167
168	if (level >= 1)
169		pass = pass && ((4*m_a*m_a*m_a+27*m_b*m_b)%p).IsPositive();
170
171	if (level >= 2)
172		pass = pass && VerifyPrime(rng, p);
173
174	return pass;
175}
176
177bool ECP::VerifyPoint(const Point &P) const
178{
179	const FieldElement &x = P.x, &y = P.y;
180	Integer p = FieldSize();
181	return P.identity ||
182		(!x.IsNegative() && x<p && !y.IsNegative() && y<p
183		&& !(((x*x+m_a)*x+m_b-y*y)%p));
184}
185
186bool ECP::Equal(const Point &P, const Point &Q) const
187{
188	if (P.identity && Q.identity)
189		return true;
190
191	if (P.identity && !Q.identity)
192		return false;
193
194	if (!P.identity && Q.identity)
195		return false;
196
197	return (GetField().Equal(P.x,Q.x) && GetField().Equal(P.y,Q.y));
198}
199
200const ECP::Point& ECP::Identity() const
201{
202	return Singleton<Point>().Ref();
203}
204
205const ECP::Point& ECP::Inverse(const Point &P) const
206{
207	if (P.identity)
208		return P;
209	else
210	{
211		m_R.identity = false;
212		m_R.x = P.x;
213		m_R.y = GetField().Inverse(P.y);
214		return m_R;
215	}
216}
217
// Affine point addition, returning P + Q (chord rule):
//   t = (Q.y - P.y) / (Q.x - P.x)
//   x = t^2 - P.x - Q.x
//   y = t*(P.x - x) - P.y
// If either input is the identity, the other input is returned by reference.
// Equal x-coordinates are special-cased:
//   - equal y means P == Q, so the result comes from Double(P);
//   - otherwise Q == -P and the sum is the identity.
// The general result is written into the mutable scratch member m_R and
// returned by reference. Every read of P and Q happens before m_R is modified,
// and the new x is moved into m_R (via swap) only at the end.
const ECP::Point& ECP::Add(const Point &P, const Point &Q) const
{
	if (P.identity) return Q;
	if (Q.identity) return P;
	if (GetField().Equal(P.x, Q.x))
		return GetField().Equal(P.y, Q.y) ? Double(P) : Identity();

	// slope of the chord through P and Q
	FieldElement t = GetField().Subtract(Q.y, P.y);
	t = GetField().Divide(t, GetField().Subtract(Q.x, P.x));
	// new x, kept in a local so it can be swapped into m_R last
	FieldElement x = GetField().Subtract(GetField().Subtract(GetField().Square(t), P.x), Q.x);
	// new y = t*(P.x - x) - P.y
	m_R.y = GetField().Subtract(GetField().Multiply(t, GetField().Subtract(P.x, x)), P.y);

	m_R.x.swap(x);
	m_R.identity = false;
	return m_R;
}
234
// Affine point doubling, returning 2P (tangent rule):
//   t = (3*P.x^2 + a) / (2*P.y)
//   x = t^2 - 2*P.x
//   y = t*(P.x - x) - P.y
// The identity, and any point whose y equals the field's additive identity
// (vertical tangent), double to the identity.
// The result is written into the mutable scratch member m_R and returned by
// reference. The new x is kept in a local until all reads of P are done.
const ECP::Point& ECP::Double(const Point &P) const
{
	if (P.identity || P.y==GetField().Identity()) return Identity();

	// t = 3*x^2 + a, computed as 2*(x^2) + x^2 + a
	FieldElement t = GetField().Square(P.x);
	t = GetField().Add(GetField().Add(GetField().Double(t), t), m_a);
	// divide by 2*y to get the tangent slope
	t = GetField().Divide(t, GetField().Double(P.y));
	FieldElement x = GetField().Subtract(GetField().Subtract(GetField().Square(t), P.x), P.x);
	m_R.y = GetField().Subtract(GetField().Multiply(t, GetField().Subtract(P.x, x)), P.y);

	m_R.x.swap(x);
	m_R.identity = false;
	return m_R;
}
249
// Replace every element of [begin, end) with its multiplicative inverse in ring,
// doing only one "real" inversion per recursion level.
// How it works:
//   - Adjacent pairs (a, b) are multiplied into vec.
//   - vec is inverted recursively.
//   - Each element's inverse is recovered as the other element of its pair
//     times 1/(a*b).
//   - An odd trailing element is carried through the recursion unchanged.
//   - A pair whose product is zero falls back to inverting each element
//     individually. What MultiplicativeInverse returns for zero is up to ring
//     and is not visible here.
// Iterator needs only: default construction, operator*, operator-,
// operator+(int) and operator+=(int). ZIterator below is the adaptor used
// with it.
template <class T, class Iterator> void ParallelInvert(const AbstractRing<T> &ring, Iterator begin, Iterator end)
{
	size_t n = end-begin;
	if (n == 1)
		*begin = ring.MultiplicativeInverse(*begin);
	else if (n > 1)
	{
		std::vector<T> vec((n+1)/2);
		unsigned int i;
		Iterator it;

		// vec[i] = product of the i-th pair; an odd leftover is copied as-is
		for (i=0, it=begin; i<n/2; i++, it+=2)
			vec[i] = ring.Multiply(*it, *(it+1));
		if (n%2 == 1)
			vec[n/2] = *it;

		// now vec[i] = 1/(a_i*b_i) (or the inverse of the leftover)
		ParallelInvert(ring, vec.begin(), vec.end());

		for (i=0, it=begin; i<n/2; i++, it+=2)
		{
			if (!vec[i])
			{
				// zero product: invert both members separately
				*it = ring.MultiplicativeInverse(*it);
				*(it+1) = ring.MultiplicativeInverse(*(it+1));
			}
			else
			{
				// (a, b) -> (b, a), then 1/a = b * 1/(ab) and 1/b = a * 1/(ab)
				std::swap(*it, *(it+1));
				*it = ring.Multiply(*it, vec[i]);
				*(it+1) = ring.Multiply(*(it+1), vec[i]);
			}
		}
		// it now points at the odd leftover, if any
		if (n%2 == 1)
			*it = vec[n/2];
	}
}
286
287struct ProjectivePoint
288{
289	ProjectivePoint() {}
290	ProjectivePoint(const Integer &x, const Integer &y, const Integer &z)
291		: x(x), y(y), z(z)	{}
292
293	Integer x,y,z;
294};
295
// Repeatedly doubles a point held in Jacobian projective coordinates over mr.
// The only visible caller is ECP::SimultaneousMultiply, which ensures mr is in
// Montgomery form before constructing this object.
// P holds the current multiple (X, Y, Z); each Double() call replaces it with 2P.
// sixteenY4 and aZ4 carry values between calls, so a*Z^4 is updated with one
// multiplication per doubling instead of being recomputed from Z.
// NOTE(review): m_b is accepted but unused, and firstDoubling/negated are set but
// never read in the code visible here.
class ProjectiveDoubling
{
public:
	// Start from Q.
	// - Identity: X = Y = 1 (multiplicative identity) and Z = 0 (additive
	//   identity). Because Z' = 2*Y*Z, doubling keeps Z at 0.
	// - Otherwise: Q's affine (x, y) is lifted to (x, y, 1), so a*Z^4 starts as a.
	ProjectiveDoubling(const ModularArithmetic &mr, const Integer &m_a, const Integer &m_b, const ECPPoint &Q)
		: mr(mr), firstDoubling(true), negated(false)
	{
		if (Q.identity)
		{
			sixteenY4 = P.x = P.y = mr.MultiplicativeIdentity();
			aZ4 = P.z = mr.Identity();
		}
		else
		{
			P.x = Q.x;
			P.y = Q.y;
			sixteenY4 = P.z = mr.MultiplicativeIdentity();
			aZ4 = m_a;
		}
	}

	// One Jacobian doubling (X, Y, Z) -> (X', Y', Z'):
	//   Z' = 2*Y*Z
	//   S  = 4*X*Y^2
	//   M  = 3*X^2 + a*Z^4
	//   X' = M^2 - 2*S
	//   Y' = M*(S - X') - 8*Y^4
	// aZ4 is advanced lazily. The value stored is a*Z_prev^4, and
	// Z^4 = 16*Y_prev^4 * Z_prev^4, so multiplying by the sixteenY4 saved from
	// the previous call gives a*Z^4 for the current Z.
	// mr.Reduce(u, v) is used as an in-place u -= v (presumably mod the field
	// modulus; see ModularArithmetic).
	void Double()
	{
		twoY = mr.Double(P.y);
		P.z = mr.Multiply(P.z, twoY);
		fourY2 = mr.Square(twoY);
		S = mr.Multiply(fourY2, P.x);
		aZ4 = mr.Multiply(aZ4, sixteenY4);
		M = mr.Square(P.x);
		M = mr.Add(mr.Add(mr.Double(M), M), aZ4);
		P.x = mr.Square(M);
		mr.Reduce(P.x, S);
		mr.Reduce(P.x, S);
		mr.Reduce(S, P.x);
		P.y = mr.Multiply(M, S);
		sixteenY4 = mr.Square(fourY2);
		mr.Reduce(P.y, mr.Half(sixteenY4));
	}

	const ModularArithmetic &mr;
	ProjectivePoint P;
	bool firstDoubling, negated;
	// sixteenY4 and aZ4 persist across Double() calls; the rest are scratch values
	Integer sixteenY4, aZ4, twoY, fourY2, S, M;
};
339
340struct ZIterator
341{
342	ZIterator() {}
343	ZIterator(std::vector<ProjectivePoint>::iterator it) : it(it) {}
344	Integer& operator*() {return it->z;}
345	int operator-(ZIterator it2) {return int(it-it2.it);}
346	ZIterator operator+(int i) {return ZIterator(it+i);}
347	ZIterator& operator+=(int i) {it+=i; return *this;}
348	std::vector<ProjectivePoint>::iterator it;
349};
350
351ECP::Point ECP::ScalarMultiply(const Point &P, const Integer &k) const
352{
353	Element result;
354	if (k.BitCount() <= 5)
355		AbstractGroup<ECPPoint>::SimultaneousMultiply(&result, P, &k, 1);
356	else
357		ECP::SimultaneousMultiply(&result, P, &k, 1);
358	return result;
359}
360
// Compute results[i] = k_i * P for the expCount exponents starting at expBegin.
// If the field is not in Montgomery form, the work is redone on a Montgomery copy
// of the curve, and the results are converted back.
//
// Otherwise:
//  1. Double P repeatedly in Jacobian coordinates (ProjectiveDoubling).
//     - At each bit position where some exponent's next window begins, snapshot
//       the current multiple 2^pos * P into `bases`. Each position is
//       snapshotted at most once and shared by every exponent that needs it.
//     - For each exponent, record the window value, the index of that base, and
//       whether the base must be negated (WindowSlider::negateNext).
//  2. Convert all snapshots to affine coordinates. This uses a single batched
//     inversion of their z values (ParallelInvert over ZIterator).
//  3. For each exponent, combine its (base, window value) pairs with
//     GeneralCascadeMultiplication.
void ECP::SimultaneousMultiply(ECP::Point *results, const ECP::Point &P, const Integer *expBegin, unsigned int expCount) const
{
	if (!GetField().IsMontgomeryRepresentation())
	{
		ECP ecpmr(*this, true);
		const ModularArithmetic &mr = ecpmr.GetField();
		ecpmr.SimultaneousMultiply(results, ToMontgomery(mr, P), expBegin, expCount);
		for (unsigned int i=0; i<expCount; i++)
			results[i] = FromMontgomery(mr, results[i]);
		return;
	}

	ProjectiveDoubling rd(GetField(), m_a, m_b, P);
	std::vector<ProjectivePoint> bases;
	std::vector<WindowSlider> exponents;
	exponents.reserve(expCount);
	// per exponent: which snapshot each window uses, whether it is negated,
	// and the window's value
	std::vector<std::vector<word32> > baseIndices(expCount);
	std::vector<std::vector<bool> > negateBase(expCount);
	std::vector<std::vector<word32> > exponentWindows(expCount);
	unsigned int i;

	for (i=0; i<expCount; i++)
	{
		// exponents must be non-negative (checked only by assert in debug builds);
		// the trailing 5 is passed to WindowSlider, presumably as the window size
		assert(expBegin->NotNegative());
		exponents.push_back(WindowSlider(*expBegin++, InversionIsFast(), 5));
		exponents[i].FindNextWindow();
	}

	unsigned int expBitPosition = 0;
	bool notDone = true;

	// walk bit positions upward, doubling once per position until every
	// exponent's windows are consumed
	while (notDone)
	{
		notDone = false;
		bool baseAdded = false;
		for (i=0; i<expCount; i++)
		{
			if (!exponents[i].finished && expBitPosition == exponents[i].windowBegin)
			{
				// snapshot 2^expBitPosition * P once for this position
				if (!baseAdded)
				{
					bases.push_back(rd.P);
					baseAdded =true;
				}

				exponentWindows[i].push_back(exponents[i].expWindow);
				baseIndices[i].push_back((word32)bases.size()-1);
				negateBase[i].push_back(exponents[i].negateNext);

				exponents[i].FindNextWindow();
			}
			notDone = notDone || !exponents[i].finished;
		}

		if (notDone)
		{
			rd.Double();
			expBitPosition++;
		}
	}

	// convert from projective to affine coordinates
	// after ParallelInvert each z holds 1/Z; then x = X/Z^2 and y = Y/Z^3
	ParallelInvert(GetField(), ZIterator(bases.begin()), ZIterator(bases.end()));
	for (i=0; i<bases.size(); i++)
	{
		if (bases[i].z.NotZero())
		{
			bases[i].y = GetField().Multiply(bases[i].y, bases[i].z);
			bases[i].z = GetField().Square(bases[i].z);
			bases[i].x = GetField().Multiply(bases[i].x, bases[i].z);
			bases[i].y = GetField().Multiply(bases[i].y, bases[i].z);
		}
	}

	// finalCascade is reused across exponents; every field read later is
	// rewritten for each entry (x/y only matter when identity is false)
	std::vector<BaseAndExponent<Point, Integer> > finalCascade;
	for (i=0; i<expCount; i++)
	{
		finalCascade.resize(baseIndices[i].size());
		for (unsigned int j=0; j<baseIndices[i].size(); j++)
		{
			ProjectivePoint &base = bases[baseIndices[i][j]];
			// a zero z (after inversion) marks the identity
			if (base.z.IsZero())
				finalCascade[j].base.identity = true;
			else
			{
				finalCascade[j].base.identity = false;
				finalCascade[j].base.x = base.x;
				if (negateBase[i][j])
					finalCascade[j].base.y = GetField().Inverse(base.y);
				else
					finalCascade[j].base.y = base.y;
			}
			// one-word positive Integer holding the window value
			finalCascade[j].exponent = Integer(Integer::POSITIVE, 0, exponentWindows[i][j]);
		}
		results[i] = GeneralCascadeMultiplication(*this, finalCascade.begin(), finalCascade.end());
	}
}
458
459ECP::Point ECP::CascadeScalarMultiply(const Point &P, const Integer &k1, const Point &Q, const Integer &k2) const
460{
461	if (!GetField().IsMontgomeryRepresentation())
462	{
463		ECP ecpmr(*this, true);
464		const ModularArithmetic &mr = ecpmr.GetField();
465		return FromMontgomery(mr, ecpmr.CascadeScalarMultiply(ToMontgomery(mr, P), k1, ToMontgomery(mr, Q), k2));
466	}
467	else
468		return AbstractGroup<Point>::CascadeScalarMultiply(P, k1, Q, k2);
469}
470
471NAMESPACE_END
472
473#endif
474