1// ida.cpp - written and placed in the public domain by Wei Dai 2 3#include "pch.h" 4#include "ida.h" 5 6#include "algebra.h" 7#include "gf2_32.h" 8#include "polynomi.h" 9#include <functional> 10 11#include "polynomi.cpp" 12 13ANONYMOUS_NAMESPACE_BEGIN 14static const CryptoPP::GF2_32 field; 15NAMESPACE_END 16 17using namespace std; 18 19NAMESPACE_BEGIN(CryptoPP) 20 21void RawIDA::IsolatedInitialize(const NameValuePairs ¶meters) 22{ 23 if (!parameters.GetIntValue("RecoveryThreshold", m_threshold)) 24 throw InvalidArgument("RawIDA: missing RecoveryThreshold argument"); 25 26 if (m_threshold <= 0) 27 throw InvalidArgument("RawIDA: RecoveryThreshold must be greater than 0"); 28 29 m_lastMapPosition = m_inputChannelMap.end(); 30 m_channelsReady = 0; 31 m_channelsFinished = 0; 32 m_w.New(m_threshold); 33 m_y.New(m_threshold); 34 m_inputQueues.reserve(m_threshold); 35 36 m_outputChannelIds.clear(); 37 m_outputChannelIdStrings.clear(); 38 m_outputQueues.clear(); 39 40 word32 outputChannelID; 41 if (parameters.GetValue("OutputChannelID", outputChannelID)) 42 AddOutputChannel(outputChannelID); 43 else 44 { 45 int nShares = parameters.GetIntValueWithDefault("NumberOfShares", m_threshold); 46 for (int i=0; i<nShares; i++) 47 AddOutputChannel(i); 48 } 49} 50 51unsigned int RawIDA::InsertInputChannel(word32 channelId) 52{ 53 if (m_lastMapPosition != m_inputChannelMap.end()) 54 { 55 if (m_lastMapPosition->first == channelId) 56 goto skipFind; 57 ++m_lastMapPosition; 58 if (m_lastMapPosition != m_inputChannelMap.end() && m_lastMapPosition->first == channelId) 59 goto skipFind; 60 } 61 m_lastMapPosition = m_inputChannelMap.find(channelId); 62 63skipFind: 64 if (m_lastMapPosition == m_inputChannelMap.end()) 65 { 66 if (m_inputChannelIds.size() == m_threshold) 67 return m_threshold; 68 69 m_lastMapPosition = m_inputChannelMap.insert(InputChannelMap::value_type(channelId, (unsigned int)m_inputChannelIds.size())).first; 70 m_inputQueues.push_back(MessageQueue()); 71 m_inputChannelIds.push_back(channelId); 72 73 if (m_inputChannelIds.size() == m_threshold) 74 PrepareInterpolation(); 75 } 76 return m_lastMapPosition->second; 77} 78 79unsigned int RawIDA::LookupInputChannel(word32 channelId) const 80{ 81 map<word32, unsigned int>::const_iterator it = m_inputChannelMap.find(channelId); 82 if (it == m_inputChannelMap.end()) 83 return m_threshold; 84 else 85 return it->second; 86} 87 88void RawIDA::ChannelData(word32 channelId, const byte *inString, size_t length, bool messageEnd) 89{ 90 int i = InsertInputChannel(channelId); 91 if (i < m_threshold) 92 { 93 lword size = m_inputQueues[i].MaxRetrievable(); 94 m_inputQueues[i].Put(inString, length); 95 if (size < 4 && size + length >= 4) 96 { 97 m_channelsReady++; 98 if (m_channelsReady == m_threshold) 99 ProcessInputQueues(); 100 } 101 102 if (messageEnd) 103 { 104 m_inputQueues[i].MessageEnd(); 105 if (m_inputQueues[i].NumberOfMessages() == 1) 106 { 107 m_channelsFinished++; 108 if (m_channelsFinished == m_threshold) 109 { 110 m_channelsReady = 0; 111 for (i=0; i<m_threshold; i++) 112 m_channelsReady += m_inputQueues[i].AnyRetrievable(); 113 ProcessInputQueues(); 114 } 115 } 116 } 117 } 118} 119 120lword RawIDA::InputBuffered(word32 channelId) const 121{ 122 int i = LookupInputChannel(channelId); 123 return i < m_threshold ? m_inputQueues[i].MaxRetrievable() : 0; 124} 125 126void RawIDA::ComputeV(unsigned int i) 127{ 128 if (i >= m_v.size()) 129 { 130 m_v.resize(i+1); 131 m_outputToInput.resize(i+1); 132 } 133 134 m_outputToInput[i] = LookupInputChannel(m_outputChannelIds[i]); 135 if (m_outputToInput[i] == m_threshold && i * m_threshold <= 1000*1000) 136 { 137 m_v[i].resize(m_threshold); 138 PrepareBulkPolynomialInterpolationAt(field, m_v[i].begin(), m_outputChannelIds[i], &(m_inputChannelIds[0]), m_w.begin(), m_threshold); 139 } 140} 141 142void RawIDA::AddOutputChannel(word32 channelId) 143{ 144 m_outputChannelIds.push_back(channelId); 145 m_outputChannelIdStrings.push_back(WordToString(channelId)); 146 m_outputQueues.push_back(ByteQueue()); 147 if (m_inputChannelIds.size() == m_threshold) 148 ComputeV((unsigned int)m_outputChannelIds.size() - 1); 149} 150 151void RawIDA::PrepareInterpolation() 152{ 153 assert(m_inputChannelIds.size() == m_threshold); 154 PrepareBulkPolynomialInterpolation(field, m_w.begin(), &(m_inputChannelIds[0]), m_threshold); 155 for (unsigned int i=0; i<m_outputChannelIds.size(); i++) 156 ComputeV(i); 157} 158 159void RawIDA::ProcessInputQueues() 160{ 161 bool finished = (m_channelsFinished == m_threshold); 162 int i; 163 164 while (finished ? m_channelsReady > 0 : m_channelsReady == m_threshold) 165 { 166 m_channelsReady = 0; 167 for (i=0; i<m_threshold; i++) 168 { 169 MessageQueue &queue = m_inputQueues[i]; 170 queue.GetWord32(m_y[i]); 171 172 if (finished) 173 m_channelsReady += queue.AnyRetrievable(); 174 else 175 m_channelsReady += queue.NumberOfMessages() > 0 || queue.MaxRetrievable() >= 4; 176 } 177 178 for (i=0; (unsigned int)i<m_outputChannelIds.size(); i++) 179 { 180 if (m_outputToInput[i] != m_threshold) 181 m_outputQueues[i].PutWord32(m_y[m_outputToInput[i]]); 182 else if (m_v[i].size() == m_threshold) 183 m_outputQueues[i].PutWord32(BulkPolynomialInterpolateAt(field, m_y.begin(), m_v[i].begin(), m_threshold)); 184 else 185 { 186 m_u.resize(m_threshold); 187 PrepareBulkPolynomialInterpolationAt(field, m_u.begin(), m_outputChannelIds[i], &(m_inputChannelIds[0]), m_w.begin(), m_threshold); 188 m_outputQueues[i].PutWord32(BulkPolynomialInterpolateAt(field, m_y.begin(), m_u.begin(), m_threshold)); 189 } 190 } 191 } 192 193 if (m_outputChannelIds.size() > 0 && m_outputQueues[0].AnyRetrievable()) 194 FlushOutputQueues(); 195 196 if (finished) 197 { 198 OutputMessageEnds(); 199 200 m_channelsReady = 0; 201 m_channelsFinished = 0; 202 m_v.clear(); 203 204 vector<MessageQueue> inputQueues; 205 vector<word32> inputChannelIds; 206 207 inputQueues.swap(m_inputQueues); 208 inputChannelIds.swap(m_inputChannelIds); 209 m_inputChannelMap.clear(); 210 m_lastMapPosition = m_inputChannelMap.end(); 211 212 for (i=0; i<m_threshold; i++) 213 { 214 inputQueues[i].GetNextMessage(); 215 inputQueues[i].TransferAllTo(*AttachedTransformation(), WordToString(inputChannelIds[i])); 216 } 217 } 218} 219 220void RawIDA::FlushOutputQueues() 221{ 222 for (unsigned int i=0; i<m_outputChannelIds.size(); i++) 223 m_outputQueues[i].TransferAllTo(*AttachedTransformation(), m_outputChannelIdStrings[i]); 224} 225 226void RawIDA::OutputMessageEnds() 227{ 228 if (GetAutoSignalPropagation() != 0) 229 { 230 for (unsigned int i=0; i<m_outputChannelIds.size(); i++) 231 AttachedTransformation()->ChannelMessageEnd(m_outputChannelIdStrings[i], GetAutoSignalPropagation()-1); 232 } 233} 234 235// **************************************************************** 236 237void SecretSharing::IsolatedInitialize(const NameValuePairs ¶meters) 238{ 239 m_pad = parameters.GetValueWithDefault("AddPadding", true); 240 m_ida.IsolatedInitialize(parameters); 241} 242 243size_t SecretSharing::Put2(const byte *begin, size_t length, int messageEnd, bool blocking) 244{ 245 if (!blocking) 246 throw BlockingInputOnly("SecretSharing"); 247 248 SecByteBlock buf(UnsignedMin(256, length)); 249 unsigned int threshold = m_ida.GetThreshold(); 250 while (length > 0) 251 { 252 size_t len = STDMIN(length, buf.size()); 253 m_ida.ChannelData(0xffffffff, begin, len, false); 254 for (unsigned int i=0; i<threshold-1; i++) 255 { 256 m_rng.GenerateBlock(buf, len); 257 m_ida.ChannelData(i, buf, len, false); 258 } 259 length -= len; 260 begin += len; 261 } 262 263 if (messageEnd) 264 { 265 m_ida.SetAutoSignalPropagation(messageEnd-1); 266 if (m_pad) 267 { 268 SecretSharing::Put(1); 269 while (m_ida.InputBuffered(0xffffffff) > 0) 270 SecretSharing::Put(0); 271 } 272 m_ida.ChannelData(0xffffffff, NULL, 0, true); 273 for (unsigned int i=0; i<m_ida.GetThreshold()-1; i++) 274 m_ida.ChannelData(i, NULL, 0, true); 275 } 276 277 return 0; 278} 279 280void SecretRecovery::IsolatedInitialize(const NameValuePairs ¶meters) 281{ 282 m_pad = parameters.GetValueWithDefault("RemovePadding", true); 283 RawIDA::IsolatedInitialize(CombinedNameValuePairs(parameters, MakeParameters("OutputChannelID", (word32)0xffffffff))); 284} 285 286void SecretRecovery::FlushOutputQueues() 287{ 288 if (m_pad) 289 m_outputQueues[0].TransferTo(*AttachedTransformation(), m_outputQueues[0].MaxRetrievable()-4); 290 else 291 m_outputQueues[0].TransferTo(*AttachedTransformation()); 292} 293 294void SecretRecovery::OutputMessageEnds() 295{ 296 if (m_pad) 297 { 298 PaddingRemover paddingRemover(new Redirector(*AttachedTransformation())); 299 m_outputQueues[0].TransferAllTo(paddingRemover); 300 } 301 302 if (GetAutoSignalPropagation() != 0) 303 AttachedTransformation()->MessageEnd(GetAutoSignalPropagation()-1); 304} 305 306// **************************************************************** 307 308void InformationDispersal::IsolatedInitialize(const NameValuePairs ¶meters) 309{ 310 m_nextChannel = 0; 311 m_pad = parameters.GetValueWithDefault("AddPadding", true); 312 m_ida.IsolatedInitialize(parameters); 313} 314 315size_t InformationDispersal::Put2(const byte *begin, size_t length, int messageEnd, bool blocking) 316{ 317 if (!blocking) 318 throw BlockingInputOnly("InformationDispersal"); 319 320 while (length--) 321 { 322 m_ida.ChannelData(m_nextChannel, begin, 1, false); 323 begin++; 324 m_nextChannel++; 325 if (m_nextChannel == m_ida.GetThreshold()) 326 m_nextChannel = 0; 327 } 328 329 if (messageEnd) 330 { 331 m_ida.SetAutoSignalPropagation(messageEnd-1); 332 if (m_pad) 333 InformationDispersal::Put(1); 334 for (word32 i=0; i<m_ida.GetThreshold(); i++) 335 m_ida.ChannelData(i, NULL, 0, true); 336 } 337 338 return 0; 339} 340 341void InformationRecovery::IsolatedInitialize(const NameValuePairs ¶meters) 342{ 343 m_pad = parameters.GetValueWithDefault("RemovePadding", true); 344 RawIDA::IsolatedInitialize(parameters); 345} 346 347void InformationRecovery::FlushOutputQueues() 348{ 349 while (m_outputQueues[0].AnyRetrievable()) 350 { 351 for (unsigned int i=0; i<m_outputChannelIds.size(); i++) 352 m_outputQueues[i].TransferTo(m_queue, 1); 353 } 354 355 if (m_pad) 356 m_queue.TransferTo(*AttachedTransformation(), m_queue.MaxRetrievable()-4*m_threshold); 357 else 358 m_queue.TransferTo(*AttachedTransformation()); 359} 360 361void InformationRecovery::OutputMessageEnds() 362{ 363 if (m_pad) 364 { 365 PaddingRemover paddingRemover(new Redirector(*AttachedTransformation())); 366 m_queue.TransferAllTo(paddingRemover); 367 } 368 369 if (GetAutoSignalPropagation() != 0) 370 AttachedTransformation()->MessageEnd(GetAutoSignalPropagation()-1); 371} 372 373size_t PaddingRemover::Put2(const byte *begin, size_t length, int messageEnd, bool blocking) 374{ 375 if (!blocking) 376 throw BlockingInputOnly("PaddingRemover"); 377 378 const byte *const end = begin + length; 379 380 if (m_possiblePadding) 381 { 382 size_t len = find_if(begin, end, bind2nd(not_equal_to<byte>(), 0)) - begin; 383 m_zeroCount += len; 384 begin += len; 385 if (begin == end) 386 return 0; 387 388 AttachedTransformation()->Put(1); 389 while (m_zeroCount--) 390 AttachedTransformation()->Put(0); 391 AttachedTransformation()->Put(*begin++); 392 m_possiblePadding = false; 393 } 394 395#if defined(_MSC_VER) && !defined(__MWERKS__) && (_MSC_VER <= 1300) 396 // VC60 and VC7 workaround: built-in reverse_iterator has two template parameters, Dinkumware only has one 397 typedef reverse_bidirectional_iterator<const byte *, const byte> RevIt; 398#elif defined(_RWSTD_NO_CLASS_PARTIAL_SPEC) 399 typedef reverse_iterator<const byte *, random_access_iterator_tag, const byte> RevIt; 400#else 401 typedef reverse_iterator<const byte *> RevIt; 402#endif 403 const byte *x = find_if(RevIt(end), RevIt(begin), bind2nd(not_equal_to<byte>(), 0)).base(); 404 if (x != begin && *(x-1) == 1) 405 { 406 AttachedTransformation()->Put(begin, x-begin-1); 407 m_possiblePadding = true; 408 m_zeroCount = end - x; 409 } 410 else 411 AttachedTransformation()->Put(begin, end-begin); 412 413 if (messageEnd) 414 { 415 m_possiblePadding = false; 416 Output(0, begin, length, messageEnd, blocking); 417 } 418 return 0; 419} 420 421NAMESPACE_END 422