1// polynomi.cpp - written and placed in the public domain by Wei Dai
2
3// Part of the code for polynomial evaluation and interpolation
4// originally came from Hal Finney's public domain secsplit.c.
5
6#include "pch.h"
7#include "polynomi.h"
8#include "secblock.h"
9
10#include <sstream>
11#include <iostream>
12
13NAMESPACE_BEGIN(CryptoPP)
14
15template <class T>
16void PolynomialOver<T>::Randomize(RandomNumberGenerator &rng, const RandomizationParameter &parameter, const Ring &ring)
17{
18	m_coefficients.resize(parameter.m_coefficientCount);
19	for (unsigned int i=0; i<m_coefficients.size(); ++i)
20		m_coefficients[i] = ring.RandomElement(rng, parameter.m_coefficientParameter);
21}
22
23template <class T>
24void PolynomialOver<T>::FromStr(const char *str, const Ring &ring)
25{
26	std::istringstream in((char *)str);
27	bool positive = true;
28	CoefficientType coef;
29	unsigned int power;
30
31	while (in)
32	{
33		std::ws(in);
34		if (in.peek() == 'x')
35			coef = ring.MultiplicativeIdentity();
36		else
37			in >> coef;
38
39		std::ws(in);
40		if (in.peek() == 'x')
41		{
42			in.get();
43			std::ws(in);
44			if (in.peek() == '^')
45			{
46				in.get();
47				in >> power;
48			}
49			else
50				power = 1;
51		}
52		else
53			power = 0;
54
55		if (!positive)
56			coef = ring.Inverse(coef);
57
58		SetCoefficient(power, coef, ring);
59
60		std::ws(in);
61		switch (in.get())
62		{
63		case '+':
64			positive = true;
65			break;
66		case '-':
67			positive = false;
68			break;
69		default:
70			return;		// something's wrong with the input string
71		}
72	}
73}
74
75template <class T>
76unsigned int PolynomialOver<T>::CoefficientCount(const Ring &ring) const
77{
78	unsigned count = m_coefficients.size();
79	while (count && ring.Equal(m_coefficients[count-1], ring.Identity()))
80		count--;
81	const_cast<std::vector<CoefficientType> &>(m_coefficients).resize(count);
82	return count;
83}
84
85template <class T>
86typename PolynomialOver<T>::CoefficientType PolynomialOver<T>::GetCoefficient(unsigned int i, const Ring &ring) const
87{
88	return (i < m_coefficients.size()) ? m_coefficients[i] : ring.Identity();
89}
90
91template <class T>
92PolynomialOver<T>&  PolynomialOver<T>::operator=(const PolynomialOver<T>& t)
93{
94	if (this != &t)
95	{
96		m_coefficients.resize(t.m_coefficients.size());
97		for (unsigned int i=0; i<m_coefficients.size(); i++)
98			m_coefficients[i] = t.m_coefficients[i];
99	}
100	return *this;
101}
102
103template <class T>
104PolynomialOver<T>& PolynomialOver<T>::Accumulate(const PolynomialOver<T>& t, const Ring &ring)
105{
106	unsigned int count = t.CoefficientCount(ring);
107
108	if (count > CoefficientCount(ring))
109		m_coefficients.resize(count, ring.Identity());
110
111	for (unsigned int i=0; i<count; i++)
112		ring.Accumulate(m_coefficients[i], t.GetCoefficient(i, ring));
113
114	return *this;
115}
116
117template <class T>
118PolynomialOver<T>& PolynomialOver<T>::Reduce(const PolynomialOver<T>& t, const Ring &ring)
119{
120	unsigned int count = t.CoefficientCount(ring);
121
122	if (count > CoefficientCount(ring))
123		m_coefficients.resize(count, ring.Identity());
124
125	for (unsigned int i=0; i<count; i++)
126		ring.Reduce(m_coefficients[i], t.GetCoefficient(i, ring));
127
128	return *this;
129}
130
131template <class T>
132typename PolynomialOver<T>::CoefficientType PolynomialOver<T>::EvaluateAt(const CoefficientType &x, const Ring &ring) const
133{
134	int degree = Degree(ring);
135
136	if (degree < 0)
137		return ring.Identity();
138
139	CoefficientType result = m_coefficients[degree];
140	for (int j=degree-1; j>=0; j--)
141	{
142		result = ring.Multiply(result, x);
143		ring.Accumulate(result, m_coefficients[j]);
144	}
145	return result;
146}
147
148template <class T>
149PolynomialOver<T>& PolynomialOver<T>::ShiftLeft(unsigned int n, const Ring &ring)
150{
151	unsigned int i = CoefficientCount(ring) + n;
152	m_coefficients.resize(i, ring.Identity());
153	while (i > n)
154	{
155		i--;
156		m_coefficients[i] = m_coefficients[i-n];
157	}
158	while (i)
159	{
160		i--;
161		m_coefficients[i] = ring.Identity();
162	}
163	return *this;
164}
165
166template <class T>
167PolynomialOver<T>& PolynomialOver<T>::ShiftRight(unsigned int n, const Ring &ring)
168{
169	unsigned int count = CoefficientCount(ring);
170	if (count > n)
171	{
172		for (unsigned int i=0; i<count-n; i++)
173			m_coefficients[i] = m_coefficients[i+n];
174		m_coefficients.resize(count-n, ring.Identity());
175	}
176	else
177		m_coefficients.resize(0, ring.Identity());
178	return *this;
179}
180
181template <class T>
182void PolynomialOver<T>::SetCoefficient(unsigned int i, const CoefficientType &value, const Ring &ring)
183{
184	if (i >= m_coefficients.size())
185		m_coefficients.resize(i+1, ring.Identity());
186	m_coefficients[i] = value;
187}
188
189template <class T>
190void PolynomialOver<T>::Negate(const Ring &ring)
191{
192	unsigned int count = CoefficientCount(ring);
193	for (unsigned int i=0; i<count; i++)
194		m_coefficients[i] = ring.Inverse(m_coefficients[i]);
195}
196
197template <class T>
198void PolynomialOver<T>::swap(PolynomialOver<T> &t)
199{
200	m_coefficients.swap(t.m_coefficients);
201}
202
203template <class T>
204bool PolynomialOver<T>::Equals(const PolynomialOver<T>& t, const Ring &ring) const
205{
206	unsigned int count = CoefficientCount(ring);
207
208	if (count != t.CoefficientCount(ring))
209		return false;
210
211	for (unsigned int i=0; i<count; i++)
212		if (!ring.Equal(m_coefficients[i], t.m_coefficients[i]))
213			return false;
214
215	return true;
216}
217
218template <class T>
219PolynomialOver<T> PolynomialOver<T>::Plus(const PolynomialOver<T>& t, const Ring &ring) const
220{
221	unsigned int i;
222	unsigned int count = CoefficientCount(ring);
223	unsigned int tCount = t.CoefficientCount(ring);
224
225	if (count > tCount)
226	{
227		PolynomialOver<T> result(ring, count);
228
229		for (i=0; i<tCount; i++)
230			result.m_coefficients[i] = ring.Add(m_coefficients[i], t.m_coefficients[i]);
231		for (; i<count; i++)
232			result.m_coefficients[i] = m_coefficients[i];
233
234		return result;
235	}
236	else
237	{
238		PolynomialOver<T> result(ring, tCount);
239
240		for (i=0; i<count; i++)
241			result.m_coefficients[i] = ring.Add(m_coefficients[i], t.m_coefficients[i]);
242		for (; i<tCount; i++)
243			result.m_coefficients[i] = t.m_coefficients[i];
244
245		return result;
246	}
247}
248
249template <class T>
250PolynomialOver<T> PolynomialOver<T>::Minus(const PolynomialOver<T>& t, const Ring &ring) const
251{
252	unsigned int i;
253	unsigned int count = CoefficientCount(ring);
254	unsigned int tCount = t.CoefficientCount(ring);
255
256	if (count > tCount)
257	{
258		PolynomialOver<T> result(ring, count);
259
260		for (i=0; i<tCount; i++)
261			result.m_coefficients[i] = ring.Subtract(m_coefficients[i], t.m_coefficients[i]);
262		for (; i<count; i++)
263			result.m_coefficients[i] = m_coefficients[i];
264
265		return result;
266	}
267	else
268	{
269		PolynomialOver<T> result(ring, tCount);
270
271		for (i=0; i<count; i++)
272			result.m_coefficients[i] = ring.Subtract(m_coefficients[i], t.m_coefficients[i]);
273		for (; i<tCount; i++)
274			result.m_coefficients[i] = ring.Inverse(t.m_coefficients[i]);
275
276		return result;
277	}
278}
279
280template <class T>
281PolynomialOver<T> PolynomialOver<T>::Inverse(const Ring &ring) const
282{
283	unsigned int count = CoefficientCount(ring);
284	PolynomialOver<T> result(ring, count);
285
286	for (unsigned int i=0; i<count; i++)
287		result.m_coefficients[i] = ring.Inverse(m_coefficients[i]);
288
289	return result;
290}
291
292template <class T>
293PolynomialOver<T> PolynomialOver<T>::Times(const PolynomialOver<T>& t, const Ring &ring) const
294{
295	if (IsZero(ring) || t.IsZero(ring))
296		return PolynomialOver<T>();
297
298	unsigned int count1 = CoefficientCount(ring), count2 = t.CoefficientCount(ring);
299	PolynomialOver<T> result(ring, count1 + count2 - 1);
300
301	for (unsigned int i=0; i<count1; i++)
302		for (unsigned int j=0; j<count2; j++)
303			ring.Accumulate(result.m_coefficients[i+j], ring.Multiply(m_coefficients[i], t.m_coefficients[j]));
304
305	return result;
306}
307
308template <class T>
309PolynomialOver<T> PolynomialOver<T>::DividedBy(const PolynomialOver<T>& t, const Ring &ring) const
310{
311	PolynomialOver<T> remainder, quotient;
312	Divide(remainder, quotient, *this, t, ring);
313	return quotient;
314}
315
316template <class T>
317PolynomialOver<T> PolynomialOver<T>::Modulo(const PolynomialOver<T>& t, const Ring &ring) const
318{
319	PolynomialOver<T> remainder, quotient;
320	Divide(remainder, quotient, *this, t, ring);
321	return remainder;
322}
323
324template <class T>
325PolynomialOver<T> PolynomialOver<T>::MultiplicativeInverse(const Ring &ring) const
326{
327	return Degree(ring)==0 ? ring.MultiplicativeInverse(m_coefficients[0]) : ring.Identity();
328}
329
330template <class T>
331bool PolynomialOver<T>::IsUnit(const Ring &ring) const
332{
333	return Degree(ring)==0 && ring.IsUnit(m_coefficients[0]);
334}
335
336template <class T>
337std::istream& PolynomialOver<T>::Input(std::istream &in, const Ring &ring)
338{
339	char c;
340	unsigned int length = 0;
341	SecBlock<char> str(length + 16);
342	bool paren = false;
343
344	std::ws(in);
345
346	if (in.peek() == '(')
347	{
348		paren = true;
349		in.get();
350	}
351
352	do
353	{
354		in.read(&c, 1);
355		str[length++] = c;
356		if (length >= str.size())
357			str.Grow(length + 16);
358	}
359	// if we started with a left paren, then read until we find a right paren,
360	// otherwise read until the end of the line
361	while (in && ((paren && c != ')') || (!paren && c != '\n')));
362
363	str[length-1] = '\0';
364	*this = PolynomialOver<T>(str, ring);
365
366	return in;
367}
368
369template <class T>
370std::ostream& PolynomialOver<T>::Output(std::ostream &out, const Ring &ring) const
371{
372	unsigned int i = CoefficientCount(ring);
373	if (i)
374	{
375		bool firstTerm = true;
376
377		while (i--)
378		{
379			if (m_coefficients[i] != ring.Identity())
380			{
381				if (firstTerm)
382				{
383					firstTerm = false;
384					if (!i || !ring.Equal(m_coefficients[i], ring.MultiplicativeIdentity()))
385						out << m_coefficients[i];
386				}
387				else
388				{
389					CoefficientType inverse = ring.Inverse(m_coefficients[i]);
390					std::ostringstream pstr, nstr;
391
392					pstr << m_coefficients[i];
393					nstr << inverse;
394
395					if (pstr.str().size() <= nstr.str().size())
396					{
397						out << " + ";
398						if (!i || !ring.Equal(m_coefficients[i], ring.MultiplicativeIdentity()))
399							out << m_coefficients[i];
400					}
401					else
402					{
403						out << " - ";
404						if (!i || !ring.Equal(inverse, ring.MultiplicativeIdentity()))
405							out << inverse;
406					}
407				}
408
409				switch (i)
410				{
411				case 0:
412					break;
413				case 1:
414					out << "x";
415					break;
416				default:
417					out << "x^" << i;
418				}
419			}
420		}
421	}
422	else
423	{
424		out << ring.Identity();
425	}
426	return out;
427}
428
429template <class T>
430void PolynomialOver<T>::Divide(PolynomialOver<T> &r, PolynomialOver<T> &q, const PolynomialOver<T> &a, const PolynomialOver<T> &d, const Ring &ring)
431{
432	unsigned int i = a.CoefficientCount(ring);
433	const int dDegree = d.Degree(ring);
434
435	if (dDegree < 0)
436		throw DivideByZero();
437
438	r = a;
439	q.m_coefficients.resize(STDMAX(0, int(i - dDegree)));
440
441	while (i > (unsigned int)dDegree)
442	{
443		--i;
444		q.m_coefficients[i-dDegree] = ring.Divide(r.m_coefficients[i], d.m_coefficients[dDegree]);
445		for (int j=0; j<=dDegree; j++)
446			ring.Reduce(r.m_coefficients[i-dDegree+j], ring.Multiply(q.m_coefficients[i-dDegree], d.m_coefficients[j]));
447	}
448
449	r.CoefficientCount(ring);	// resize r.m_coefficients
450}
451
452// ********************************************************
453
454// helper function for Interpolate() and InterpolateAt()
455template <class T>
456void RingOfPolynomialsOver<T>::CalculateAlpha(std::vector<CoefficientType> &alpha, const CoefficientType x[], const CoefficientType y[], unsigned int n) const
457{
458	for (unsigned int j=0; j<n; ++j)
459		alpha[j] = y[j];
460
461	for (unsigned int k=1; k<n; ++k)
462	{
463		for (unsigned int j=n-1; j>=k; --j)
464		{
465			m_ring.Reduce(alpha[j], alpha[j-1]);
466
467			CoefficientType d = m_ring.Subtract(x[j], x[j-k]);
468			if (!m_ring.IsUnit(d))
469				throw InterpolationFailed();
470			alpha[j] = m_ring.Divide(alpha[j], d);
471		}
472	}
473}
474
475template <class T>
476typename RingOfPolynomialsOver<T>::Element RingOfPolynomialsOver<T>::Interpolate(const CoefficientType x[], const CoefficientType y[], unsigned int n) const
477{
478	assert(n > 0);
479
480	std::vector<CoefficientType> alpha(n);
481	CalculateAlpha(alpha, x, y, n);
482
483	std::vector<CoefficientType> coefficients((size_t)n, m_ring.Identity());
484	coefficients[0] = alpha[n-1];
485
486	for (int j=n-2; j>=0; --j)
487	{
488		for (unsigned int i=n-j-1; i>0; i--)
489			coefficients[i] = m_ring.Subtract(coefficients[i-1], m_ring.Multiply(coefficients[i], x[j]));
490
491		coefficients[0] = m_ring.Subtract(alpha[j], m_ring.Multiply(coefficients[0], x[j]));
492	}
493
494	return PolynomialOver<T>(coefficients.begin(), coefficients.end());
495}
496
497template <class T>
498typename RingOfPolynomialsOver<T>::CoefficientType RingOfPolynomialsOver<T>::InterpolateAt(const CoefficientType &position, const CoefficientType x[], const CoefficientType y[], unsigned int n) const
499{
500	assert(n > 0);
501
502	std::vector<CoefficientType> alpha(n);
503	CalculateAlpha(alpha, x, y, n);
504
505	CoefficientType result = alpha[n-1];
506	for (int j=n-2; j>=0; --j)
507	{
508		result = m_ring.Multiply(result, m_ring.Subtract(position, x[j]));
509		m_ring.Accumulate(result, alpha[j]);
510	}
511	return result;
512}
513
template <class Ring, class Element>
void PrepareBulkPolynomialInterpolation(const Ring &ring, Element *w, const Element x[], unsigned int n)
{
	// Computes w[i] = 1 / prod_{j != i} (x[i] - x[j]) for i < n. These weights
	// depend only on the x values and are the w[] input of
	// PrepareBulkPolynomialInterpolationAt.
	for (unsigned int i = 0; i < n; i++)
	{
		Element denominator = ring.MultiplicativeIdentity();
		for (unsigned int j = 0; j < i; j++)
			denominator = ring.Multiply(denominator, ring.Subtract(x[i], x[j]));
		for (unsigned int j = i + 1; j < n; j++)
			denominator = ring.Multiply(denominator, ring.Subtract(x[i], x[j]));
		w[i] = ring.MultiplicativeInverse(denominator);
	}
}
526
template <class Ring, class Element>
void PrepareBulkPolynomialInterpolationAt(const Ring &ring, Element *v, const Element &position, const Element x[], const Element w[], unsigned int n)
{
	// Computes v[i] = w[i] * prod_{j != i} (position - x[j]) for i < n. With w
	// from PrepareBulkPolynomialInterpolation, sum_i y[i] * v[i] is the value at
	// position of the polynomial through the points (x[i], y[i]).
	//
	// The "product of all other factors" is formed with a binary tree stored in
	// a[]. Node k has children 2k+1 and 2k+2, and leaf a[n-1+i] holds
	// (position - x[i]). An upward pass fills internal nodes 1..n-2 with the
	// product of their subtree's leaves. A downward pass then leaves each node
	// holding the product of all leaves outside its subtree.
	assert(n > 0);
	if (n == 0)
		return;	// nothing to compute; also keeps 2*n-1 and n-1 from wrapping in NDEBUG builds

	std::vector<Element> a(2*n-1);
	unsigned int i;

	for (i=0; i<n; i++)
		a[n-1+i] = ring.Subtract(position, x[i]);

	// Upward pass: node i-1 becomes the product of its children 2i-1 and 2i.
	for (i=n-1; i>1; i--)
		a[i-1] = ring.Multiply(a[2*i], a[2*i-1]);

	a[0] = ring.MultiplicativeIdentity();	// nothing lies outside the root's subtree

	// Downward pass: after the swap each child holds its sibling's subtree
	// product. Multiplying by the parent's "outside" product completes it.
	for (i=0; i<n-1; i++)
	{
		std::swap(a[2*i+1], a[2*i+2]);
		a[2*i+1] = ring.Multiply(a[i], a[2*i+1]);
		a[2*i+2] = ring.Multiply(a[i], a[2*i+2]);
	}

	for (i=0; i<n; i++)
		v[i] = ring.Multiply(a[n-1+i], w[i]);
}
553
template <class Ring, class Element>
Element BulkPolynomialInterpolateAt(const Ring &ring, const Element y[], const Element v[], unsigned int n)
{
	// Ring dot product sum_{i<n} y[i] * v[i]. With v from
	// PrepareBulkPolynomialInterpolationAt, this is the interpolated value.
	Element sum = ring.Identity();
	for (unsigned int k = 0; k != n; ++k)
	{
		const Element &term = ring.Multiply(y[k], v[k]);
		ring.Accumulate(sum, term);
	}
	return sum;
}
562
563// ********************************************************
564
565template <class T, int instance>
566const PolynomialOverFixedRing<T, instance> &PolynomialOverFixedRing<T, instance>::Zero()
567{
568	return Singleton<ThisType>().Ref();
569}
570
571template <class T, int instance>
572const PolynomialOverFixedRing<T, instance> &PolynomialOverFixedRing<T, instance>::One()
573{
574	return Singleton<ThisType, NewOnePolynomial>().Ref();
575}
576
577NAMESPACE_END
578