1// algebra.cpp - written and placed in the public domain by Wei Dai 2 3#include "pch.h" 4 5#ifndef CRYPTOPP_ALGEBRA_CPP // SunCC workaround: compiler could cause this file to be included twice 6#define CRYPTOPP_ALGEBRA_CPP 7 8#include "algebra.h" 9#include "integer.h" 10 11#include <vector> 12 13NAMESPACE_BEGIN(CryptoPP) 14 15template <class T> const T& AbstractGroup<T>::Double(const Element &a) const 16{ 17 return Add(a, a); 18} 19 20template <class T> const T& AbstractGroup<T>::Subtract(const Element &a, const Element &b) const 21{ 22 // make copy of a in case Inverse() overwrites it 23 Element a1(a); 24 return Add(a1, Inverse(b)); 25} 26 27template <class T> T& AbstractGroup<T>::Accumulate(Element &a, const Element &b) const 28{ 29 return a = Add(a, b); 30} 31 32template <class T> T& AbstractGroup<T>::Reduce(Element &a, const Element &b) const 33{ 34 return a = Subtract(a, b); 35} 36 37template <class T> const T& AbstractRing<T>::Square(const Element &a) const 38{ 39 return Multiply(a, a); 40} 41 42template <class T> const T& AbstractRing<T>::Divide(const Element &a, const Element &b) const 43{ 44 // make copy of a in case MultiplicativeInverse() overwrites it 45 Element a1(a); 46 return Multiply(a1, MultiplicativeInverse(b)); 47} 48 49template <class T> const T& AbstractEuclideanDomain<T>::Mod(const Element &a, const Element &b) const 50{ 51 Element q; 52 DivisionAlgorithm(result, q, a, b); 53 return result; 54} 55 56template <class T> const T& AbstractEuclideanDomain<T>::Gcd(const Element &a, const Element &b) const 57{ 58 Element g[3]={b, a}; 59 unsigned int i0=0, i1=1, i2=2; 60 61 while (!Equal(g[i1], this->Identity())) 62 { 63 g[i2] = Mod(g[i0], g[i1]); 64 unsigned int t = i0; i0 = i1; i1 = i2; i2 = t; 65 } 66 67 return result = g[i0]; 68} 69 70template <class T> const typename QuotientRing<T>::Element& QuotientRing<T>::MultiplicativeInverse(const Element &a) const 71{ 72 Element g[3]={m_modulus, a}; 73 Element v[3]={m_domain.Identity(), m_domain.MultiplicativeIdentity()}; 74 Element y; 75 unsigned int i0=0, i1=1, i2=2; 76 77 while (!Equal(g[i1], Identity())) 78 { 79 // y = g[i0] / g[i1]; 80 // g[i2] = g[i0] % g[i1]; 81 m_domain.DivisionAlgorithm(g[i2], y, g[i0], g[i1]); 82 // v[i2] = v[i0] - (v[i1] * y); 83 v[i2] = m_domain.Subtract(v[i0], m_domain.Multiply(v[i1], y)); 84 unsigned int t = i0; i0 = i1; i1 = i2; i2 = t; 85 } 86 87 return m_domain.IsUnit(g[i0]) ? m_domain.Divide(v[i0], g[i0]) : m_domain.Identity(); 88} 89 90template <class T> T AbstractGroup<T>::ScalarMultiply(const Element &base, const Integer &exponent) const 91{ 92 Element result; 93 SimultaneousMultiply(&result, base, &exponent, 1); 94 return result; 95} 96 97template <class T> T AbstractGroup<T>::CascadeScalarMultiply(const Element &x, const Integer &e1, const Element &y, const Integer &e2) const 98{ 99 const unsigned expLen = STDMAX(e1.BitCount(), e2.BitCount()); 100 if (expLen==0) 101 return Identity(); 102 103 const unsigned w = (expLen <= 46 ? 1 : (expLen <= 260 ? 2 : 3)); 104 const unsigned tableSize = 1<<w; 105 std::vector<Element> powerTable(tableSize << w); 106 107 powerTable[1] = x; 108 powerTable[tableSize] = y; 109 if (w==1) 110 powerTable[3] = Add(x,y); 111 else 112 { 113 powerTable[2] = Double(x); 114 powerTable[2*tableSize] = Double(y); 115 116 unsigned i, j; 117 118 for (i=3; i<tableSize; i+=2) 119 powerTable[i] = Add(powerTable[i-2], powerTable[2]); 120 for (i=1; i<tableSize; i+=2) 121 for (j=i+tableSize; j<(tableSize<<w); j+=tableSize) 122 powerTable[j] = Add(powerTable[j-tableSize], y); 123 124 for (i=3*tableSize; i<(tableSize<<w); i+=2*tableSize) 125 powerTable[i] = Add(powerTable[i-2*tableSize], powerTable[2*tableSize]); 126 for (i=tableSize; i<(tableSize<<w); i+=2*tableSize) 127 for (j=i+2; j<i+tableSize; j+=2) 128 powerTable[j] = Add(powerTable[j-1], x); 129 } 130 131 Element result; 132 unsigned power1 = 0, power2 = 0, prevPosition = expLen-1; 133 bool firstTime = true; 134 135 for (int i = expLen-1; i>=0; i--) 136 { 137 power1 = 2*power1 + e1.GetBit(i); 138 power2 = 2*power2 + e2.GetBit(i); 139 140 if (i==0 || 2*power1 >= tableSize || 2*power2 >= tableSize) 141 { 142 unsigned squaresBefore = prevPosition-i; 143 unsigned squaresAfter = 0; 144 prevPosition = i; 145 while ((power1 || power2) && power1%2 == 0 && power2%2==0) 146 { 147 power1 /= 2; 148 power2 /= 2; 149 squaresBefore--; 150 squaresAfter++; 151 } 152 if (firstTime) 153 { 154 result = powerTable[(power2<<w) + power1]; 155 firstTime = false; 156 } 157 else 158 { 159 while (squaresBefore--) 160 result = Double(result); 161 if (power1 || power2) 162 Accumulate(result, powerTable[(power2<<w) + power1]); 163 } 164 while (squaresAfter--) 165 result = Double(result); 166 power1 = power2 = 0; 167 } 168 } 169 return result; 170} 171 172template <class Element, class Iterator> Element GeneralCascadeMultiplication(const AbstractGroup<Element> &group, Iterator begin, Iterator end) 173{ 174 if (end-begin == 1) 175 return group.ScalarMultiply(begin->base, begin->exponent); 176 else if (end-begin == 2) 177 return group.CascadeScalarMultiply(begin->base, begin->exponent, (begin+1)->base, (begin+1)->exponent); 178 else 179 { 180 Integer q, t; 181 Iterator last = end; 182 --last; 183 184 std::make_heap(begin, end); 185 std::pop_heap(begin, end); 186 187 while (!!begin->exponent) 188 { 189 // last->exponent is largest exponent, begin->exponent is next largest 190 t = last->exponent; 191 Integer::Divide(last->exponent, q, t, begin->exponent); 192 193 if (q == Integer::One()) 194 group.Accumulate(begin->base, last->base); // avoid overhead of ScalarMultiply() 195 else 196 group.Accumulate(begin->base, group.ScalarMultiply(last->base, q)); 197 198 std::push_heap(begin, end); 199 std::pop_heap(begin, end); 200 } 201 202 return group.ScalarMultiply(last->base, last->exponent); 203 } 204} 205 206struct WindowSlider 207{ 208 WindowSlider(const Integer &expIn, bool fastNegate, unsigned int windowSizeIn=0) 209 : exp(expIn), windowModulus(Integer::One()), windowSize(windowSizeIn), windowBegin(0), fastNegate(fastNegate), firstTime(true), finished(false) 210 { 211 if (windowSize == 0) 212 { 213 unsigned int expLen = exp.BitCount(); 214 windowSize = expLen <= 17 ? 1 : (expLen <= 24 ? 2 : (expLen <= 70 ? 3 : (expLen <= 197 ? 4 : (expLen <= 539 ? 5 : (expLen <= 1434 ? 6 : 7))))); 215 } 216 windowModulus <<= windowSize; 217 } 218 219 void FindNextWindow() 220 { 221 unsigned int expLen = exp.WordCount() * WORD_BITS; 222 unsigned int skipCount = firstTime ? 0 : windowSize; 223 firstTime = false; 224 while (!exp.GetBit(skipCount)) 225 { 226 if (skipCount >= expLen) 227 { 228 finished = true; 229 return; 230 } 231 skipCount++; 232 } 233 234 exp >>= skipCount; 235 windowBegin += skipCount; 236 expWindow = word32(exp % (word(1) << windowSize)); 237 238 if (fastNegate && exp.GetBit(windowSize)) 239 { 240 negateNext = true; 241 expWindow = (word32(1) << windowSize) - expWindow; 242 exp += windowModulus; 243 } 244 else 245 negateNext = false; 246 } 247 248 Integer exp, windowModulus; 249 unsigned int windowSize, windowBegin; 250 word32 expWindow; 251 bool fastNegate, negateNext, firstTime, finished; 252}; 253 254template <class T> 255void AbstractGroup<T>::SimultaneousMultiply(T *results, const T &base, const Integer *expBegin, unsigned int expCount) const 256{ 257 std::vector<std::vector<Element> > buckets(expCount); 258 std::vector<WindowSlider> exponents; 259 exponents.reserve(expCount); 260 unsigned int i; 261 262 for (i=0; i<expCount; i++) 263 { 264 assert(expBegin->NotNegative()); 265 exponents.push_back(WindowSlider(*expBegin++, InversionIsFast(), 0)); 266 exponents[i].FindNextWindow(); 267 buckets[i].resize(1<<(exponents[i].windowSize-1), Identity()); 268 } 269 270 unsigned int expBitPosition = 0; 271 Element g = base; 272 bool notDone = true; 273 274 while (notDone) 275 { 276 notDone = false; 277 for (i=0; i<expCount; i++) 278 { 279 if (!exponents[i].finished && expBitPosition == exponents[i].windowBegin) 280 { 281 Element &bucket = buckets[i][exponents[i].expWindow/2]; 282 if (exponents[i].negateNext) 283 Accumulate(bucket, Inverse(g)); 284 else 285 Accumulate(bucket, g); 286 exponents[i].FindNextWindow(); 287 } 288 notDone = notDone || !exponents[i].finished; 289 } 290 291 if (notDone) 292 { 293 g = Double(g); 294 expBitPosition++; 295 } 296 } 297 298 for (i=0; i<expCount; i++) 299 { 300 Element &r = *results++; 301 r = buckets[i][buckets[i].size()-1]; 302 if (buckets[i].size() > 1) 303 { 304 for (int j = (int)buckets[i].size()-2; j >= 1; j--) 305 { 306 Accumulate(buckets[i][j], buckets[i][j+1]); 307 Accumulate(r, buckets[i][j]); 308 } 309 Accumulate(buckets[i][0], buckets[i][1]); 310 r = Add(Double(r), buckets[i][0]); 311 } 312 } 313} 314 315template <class T> T AbstractRing<T>::Exponentiate(const Element &base, const Integer &exponent) const 316{ 317 Element result; 318 SimultaneousExponentiate(&result, base, &exponent, 1); 319 return result; 320} 321 322template <class T> T AbstractRing<T>::CascadeExponentiate(const Element &x, const Integer &e1, const Element &y, const Integer &e2) const 323{ 324 return MultiplicativeGroup().AbstractGroup<T>::CascadeScalarMultiply(x, e1, y, e2); 325} 326 327template <class Element, class Iterator> Element GeneralCascadeExponentiation(const AbstractRing<Element> &ring, Iterator begin, Iterator end) 328{ 329 return GeneralCascadeMultiplication<Element>(ring.MultiplicativeGroup(), begin, end); 330} 331 332template <class T> 333void AbstractRing<T>::SimultaneousExponentiate(T *results, const T &base, const Integer *exponents, unsigned int expCount) const 334{ 335 MultiplicativeGroup().AbstractGroup<T>::SimultaneousMultiply(results, base, exponents, expCount); 336} 337 338NAMESPACE_END 339 340#endif 341