1// ec2n.cpp - written and placed in the public domain by Wei Dai
2
3#include "pch.h"
4
5#ifndef CRYPTOPP_IMPORTS
6
7#include "ec2n.h"
8#include "asn.h"
9
10#include "algebra.cpp"
11#include "eprecomp.cpp"
12
13NAMESPACE_BEGIN(CryptoPP)
14
15EC2N::EC2N(BufferedTransformation &bt)
16	: m_field(BERDecodeGF2NP(bt))
17{
18	BERSequenceDecoder seq(bt);
19	m_field->BERDecodeElement(seq, m_a);
20	m_field->BERDecodeElement(seq, m_b);
21	// skip optional seed
22	if (!seq.EndReached())
23		BERDecodeOctetString(seq, TheBitBucket());
24	seq.MessageEnd();
25}
26
27void EC2N::DEREncode(BufferedTransformation &bt) const
28{
29	m_field->DEREncode(bt);
30	DERSequenceEncoder seq(bt);
31	m_field->DEREncodeElement(seq, m_a);
32	m_field->DEREncodeElement(seq, m_b);
33	seq.MessageEnd();
34}
35
36bool EC2N::DecodePoint(EC2N::Point &P, const byte *encodedPoint, size_t encodedPointLen) const
37{
38	StringStore store(encodedPoint, encodedPointLen);
39	return DecodePoint(P, store, encodedPointLen);
40}
41
42bool EC2N::DecodePoint(EC2N::Point &P, BufferedTransformation &bt, size_t encodedPointLen) const
43{
44	byte type;
45	if (encodedPointLen < 1 || !bt.Get(type))
46		return false;
47
48	switch (type)
49	{
50	case 0:
51		P.identity = true;
52		return true;
53	case 2:
54	case 3:
55	{
56		if (encodedPointLen != EncodedPointSize(true))
57			return false;
58
59		P.identity = false;
60		P.x.Decode(bt, m_field->MaxElementByteLength());
61
62		if (P.x.IsZero())
63		{
64			P.y = m_field->SquareRoot(m_b);
65			return true;
66		}
67
68		FieldElement z = m_field->Square(P.x);
69		assert(P.x == m_field->SquareRoot(z));
70		P.y = m_field->Divide(m_field->Add(m_field->Multiply(z, m_field->Add(P.x, m_a)), m_b), z);
71		assert(P.x == m_field->Subtract(m_field->Divide(m_field->Subtract(m_field->Multiply(P.y, z), m_b), z), m_a));
72		z = m_field->SolveQuadraticEquation(P.y);
73		assert(m_field->Add(m_field->Square(z), z) == P.y);
74		z.SetCoefficient(0, type & 1);
75
76		P.y = m_field->Multiply(z, P.x);
77		return true;
78	}
79	case 4:
80	{
81		if (encodedPointLen != EncodedPointSize(false))
82			return false;
83
84		unsigned int len = m_field->MaxElementByteLength();
85		P.identity = false;
86		P.x.Decode(bt, len);
87		P.y.Decode(bt, len);
88		return true;
89	}
90	default:
91		return false;
92	}
93}
94
95void EC2N::EncodePoint(BufferedTransformation &bt, const Point &P, bool compressed) const
96{
97	if (P.identity)
98		NullStore().TransferTo(bt, EncodedPointSize(compressed));
99	else if (compressed)
100	{
101		bt.Put(2 + (!P.x ? 0 : m_field->Divide(P.y, P.x).GetBit(0)));
102		P.x.Encode(bt, m_field->MaxElementByteLength());
103	}
104	else
105	{
106		unsigned int len = m_field->MaxElementByteLength();
107		bt.Put(4);	// uncompressed
108		P.x.Encode(bt, len);
109		P.y.Encode(bt, len);
110	}
111}
112
113void EC2N::EncodePoint(byte *encodedPoint, const Point &P, bool compressed) const
114{
115	ArraySink sink(encodedPoint, EncodedPointSize(compressed));
116	EncodePoint(sink, P, compressed);
117	assert(sink.TotalPutLength() == EncodedPointSize(compressed));
118}
119
120EC2N::Point EC2N::BERDecodePoint(BufferedTransformation &bt) const
121{
122	SecByteBlock str;
123	BERDecodeOctetString(bt, str);
124	Point P;
125	if (!DecodePoint(P, str, str.size()))
126		BERDecodeError();
127	return P;
128}
129
130void EC2N::DEREncodePoint(BufferedTransformation &bt, const Point &P, bool compressed) const
131{
132	SecByteBlock str(EncodedPointSize(compressed));
133	EncodePoint(str, P, compressed);
134	DEREncodeOctetString(bt, str);
135}
136
137bool EC2N::ValidateParameters(RandomNumberGenerator &rng, unsigned int level) const
138{
139	bool pass = !!m_b;
140	pass = pass && m_a.CoefficientCount() <= m_field->MaxElementBitLength();
141	pass = pass && m_b.CoefficientCount() <= m_field->MaxElementBitLength();
142
143	if (level >= 1)
144		pass = pass && m_field->GetModulus().IsIrreducible();
145
146	return pass;
147}
148
149bool EC2N::VerifyPoint(const Point &P) const
150{
151	const FieldElement &x = P.x, &y = P.y;
152	return P.identity ||
153		(x.CoefficientCount() <= m_field->MaxElementBitLength()
154		&& y.CoefficientCount() <= m_field->MaxElementBitLength()
155		&& !(((x+m_a)*x*x+m_b-(x+y)*y)%m_field->GetModulus()));
156}
157
158bool EC2N::Equal(const Point &P, const Point &Q) const
159{
160	if (P.identity && Q.identity)
161		return true;
162
163	if (P.identity && !Q.identity)
164		return false;
165
166	if (!P.identity && Q.identity)
167		return false;
168
169	return (m_field->Equal(P.x,Q.x) && m_field->Equal(P.y,Q.y));
170}
171
172const EC2N::Point& EC2N::Identity() const
173{
174	return Singleton<Point>().Ref();
175}
176
177const EC2N::Point& EC2N::Inverse(const Point &P) const
178{
179	if (P.identity)
180		return P;
181	else
182	{
183		m_R.identity = false;
184		m_R.y = m_field->Add(P.x, P.y);
185		m_R.x = P.x;
186		return m_R;
187	}
188}
189
// Returns P + Q on the curve y^2 + xy = x^3 + ax^2 + b.
// The returned reference may be P, Q, the shared identity, or the scratch
// member m_R; a reference to m_R is overwritten by the next Add/Double/Inverse
// call on this object.
const EC2N::Point& EC2N::Add(const Point &P, const Point &Q) const
{
	if (P.identity) return Q;
	if (Q.identity) return P;
	if (Equal(P, Q)) return Double(P);
	// Q's negation is (Q.x, Q.x + Q.y); if P equals it, P + Q is the identity.
	if (m_field->Equal(P.x, Q.x) && m_field->Equal(P.y, m_field->Add(Q.x, Q.y))) return Identity();

	// Chord slope t = (P.y + Q.y) / (P.x + Q.x).
	// NOTE(review): for points not on the curve, P.x == Q.x can still reach this
	// line, and the divisor is then zero.
	FieldElement t = m_field->Add(P.y, Q.y);
	t = m_field->Divide(t, m_field->Add(P.x, Q.x));
	// x = t^2 + t + Q.x + a (P.x is added below, after it has been used for y).
	FieldElement x = m_field->Square(t);
	m_field->Accumulate(x, t);
	m_field->Accumulate(x, Q.x);
	m_field->Accumulate(x, m_a);
	// y = P.y + t*(t^2 + t + Q.x + a), which equals P.y + t*(P.x + x3).
	m_R.y = m_field->Add(P.y, m_field->Multiply(t, x));
	// x3 = t^2 + t + P.x + Q.x + a. m_R.x is not modified until the swap below,
	// so P.x is still intact here even if P is m_R.
	m_field->Accumulate(x, P.x);
	// y3 = P.y + t*(P.x + x3) + x3.
	m_field->Accumulate(m_R.y, x);

	m_R.x.swap(x);
	m_R.identity = false;
	return m_R;
}
211
// Returns 2P on the curve y^2 + xy = x^3 + ax^2 + b.
// The returned reference may be P, the shared identity, or the scratch member
// m_R. The statement order allows P to be m_R itself: P.y is last read before
// m_R.y is written, and P.x is last read before m_R.x is written.
const EC2N::Point& EC2N::Double(const Point &P) const
{
	if (P.identity) return P;
	// x == 0 (not invertible): then -P = (0, 0 + y) = P, so 2P is the identity.
	if (!m_field->IsUnit(P.x)) return Identity();

	// Tangent slope t = y/x + x.
	FieldElement t = m_field->Divide(P.y, P.x);
	m_field->Accumulate(t, P.x);
	m_R.y = m_field->Square(P.x);
	// x3 = t^2 + t + a.
	m_R.x = m_field->Square(t);
	m_field->Accumulate(m_R.x, t);
	m_field->Accumulate(m_R.x, m_a);
	// y3 = x^2 + t*x3 + x3.
	m_field->Accumulate(m_R.y, m_field->Multiply(t, m_R.x));
	m_field->Accumulate(m_R.y, m_R.x);

	m_R.identity = false;
	return m_R;
}
229
230// ********************************************************
231
232/*
233EcPrecomputation<EC2N>& EcPrecomputation<EC2N>::operator=(const EcPrecomputation<EC2N> &rhs)
234{
235	m_ec = rhs.m_ec;
236	m_ep = rhs.m_ep;
237	m_ep.m_group = m_ec.get();
238	return *this;
239}
240
241void EcPrecomputation<EC2N>::SetCurveAndBase(const EC2N &ec, const EC2N::Point &base)
242{
243	m_ec.reset(new EC2N(ec));
244	m_ep.SetGroupAndBase(*m_ec, base);
245}
246
247void EcPrecomputation<EC2N>::Precompute(unsigned int maxExpBits, unsigned int storage)
248{
249	m_ep.Precompute(maxExpBits, storage);
250}
251
252void EcPrecomputation<EC2N>::Load(BufferedTransformation &bt)
253{
254	BERSequenceDecoder seq(bt);
255	word32 version;
256	BERDecodeUnsigned<word32>(seq, version, INTEGER, 1, 1);
257	m_ep.m_exponentBase.BERDecode(seq);
258	m_ep.m_windowSize = m_ep.m_exponentBase.BitCount() - 1;
259	m_ep.m_bases.clear();
260	while (!seq.EndReached())
261		m_ep.m_bases.push_back(m_ec->BERDecodePoint(seq));
262	seq.MessageEnd();
263}
264
265void EcPrecomputation<EC2N>::Save(BufferedTransformation &bt) const
266{
267	DERSequenceEncoder seq(bt);
268	DEREncodeUnsigned<word32>(seq, 1);	// version
269	m_ep.m_exponentBase.DEREncode(seq);
270	for (unsigned i=0; i<m_ep.m_bases.size(); i++)
271		m_ec->DEREncodePoint(seq, m_ep.m_bases[i]);
272	seq.MessageEnd();
273}
274
275EC2N::Point EcPrecomputation<EC2N>::Exponentiate(const Integer &exponent) const
276{
277	return m_ep.Exponentiate(exponent);
278}
279
280EC2N::Point EcPrecomputation<EC2N>::CascadeExponentiate(const Integer &exponent, const DL_FixedBasePrecomputation<Element> &pc2, const Integer &exponent2) const
281{
282	return m_ep.CascadeExponentiate(exponent, static_cast<const EcPrecomputation<EC2N> &>(pc2).m_ep, exponent2);
283}
284*/
285
286NAMESPACE_END
287
288#endif
289