1// ida.cpp - written and placed in the public domain by Wei Dai
2
3#include "pch.h"
4#include "ida.h"
5
6#include "algebra.h"
7#include "gf2_32.h"
8#include "polynomi.h"
9#include <functional>
10
11#include "polynomi.cpp"
12
13ANONYMOUS_NAMESPACE_BEGIN
14static const CryptoPP::GF2_32 field;
15NAMESPACE_END
16
17using namespace std;
18
19NAMESPACE_BEGIN(CryptoPP)
20
21void RawIDA::IsolatedInitialize(const NameValuePairs &parameters)
22{
23	if (!parameters.GetIntValue("RecoveryThreshold", m_threshold))
24		throw InvalidArgument("RawIDA: missing RecoveryThreshold argument");
25
26	if (m_threshold <= 0)
27		throw InvalidArgument("RawIDA: RecoveryThreshold must be greater than 0");
28
29	m_lastMapPosition = m_inputChannelMap.end();
30	m_channelsReady = 0;
31	m_channelsFinished = 0;
32	m_w.New(m_threshold);
33	m_y.New(m_threshold);
34	m_inputQueues.reserve(m_threshold);
35
36	m_outputChannelIds.clear();
37	m_outputChannelIdStrings.clear();
38	m_outputQueues.clear();
39
40	word32 outputChannelID;
41	if (parameters.GetValue("OutputChannelID", outputChannelID))
42		AddOutputChannel(outputChannelID);
43	else
44	{
45		int nShares = parameters.GetIntValueWithDefault("NumberOfShares", m_threshold);
46		for (int i=0; i<nShares; i++)
47			AddOutputChannel(i);
48	}
49}
50
51unsigned int RawIDA::InsertInputChannel(word32 channelId)
52{
53	if (m_lastMapPosition != m_inputChannelMap.end())
54	{
55		if (m_lastMapPosition->first == channelId)
56			goto skipFind;
57		++m_lastMapPosition;
58		if (m_lastMapPosition != m_inputChannelMap.end() && m_lastMapPosition->first == channelId)
59			goto skipFind;
60	}
61	m_lastMapPosition = m_inputChannelMap.find(channelId);
62
63skipFind:
64	if (m_lastMapPosition == m_inputChannelMap.end())
65	{
66		if (m_inputChannelIds.size() == m_threshold)
67			return m_threshold;
68
69		m_lastMapPosition = m_inputChannelMap.insert(InputChannelMap::value_type(channelId, (unsigned int)m_inputChannelIds.size())).first;
70		m_inputQueues.push_back(MessageQueue());
71		m_inputChannelIds.push_back(channelId);
72
73		if (m_inputChannelIds.size() == m_threshold)
74			PrepareInterpolation();
75	}
76	return m_lastMapPosition->second;
77}
78
79unsigned int RawIDA::LookupInputChannel(word32 channelId) const
80{
81	map<word32, unsigned int>::const_iterator it = m_inputChannelMap.find(channelId);
82	if (it == m_inputChannelMap.end())
83		return m_threshold;
84	else
85		return it->second;
86}
87
88void RawIDA::ChannelData(word32 channelId, const byte *inString, size_t length, bool messageEnd)
89{
90	int i = InsertInputChannel(channelId);
91	if (i < m_threshold)
92	{
93		lword size = m_inputQueues[i].MaxRetrievable();
94		m_inputQueues[i].Put(inString, length);
95		if (size < 4 && size + length >= 4)
96		{
97			m_channelsReady++;
98			if (m_channelsReady == m_threshold)
99				ProcessInputQueues();
100		}
101
102		if (messageEnd)
103		{
104			m_inputQueues[i].MessageEnd();
105			if (m_inputQueues[i].NumberOfMessages() == 1)
106			{
107				m_channelsFinished++;
108				if (m_channelsFinished == m_threshold)
109				{
110					m_channelsReady = 0;
111					for (i=0; i<m_threshold; i++)
112						m_channelsReady += m_inputQueues[i].AnyRetrievable();
113					ProcessInputQueues();
114				}
115			}
116		}
117	}
118}
119
120lword RawIDA::InputBuffered(word32 channelId) const
121{
122	int i = LookupInputChannel(channelId);
123	return i < m_threshold ? m_inputQueues[i].MaxRetrievable() : 0;
124}
125
126void RawIDA::ComputeV(unsigned int i)
127{
128	if (i >= m_v.size())
129	{
130		m_v.resize(i+1);
131		m_outputToInput.resize(i+1);
132	}
133
134	m_outputToInput[i] = LookupInputChannel(m_outputChannelIds[i]);
135	if (m_outputToInput[i] == m_threshold && i * m_threshold <= 1000*1000)
136	{
137		m_v[i].resize(m_threshold);
138		PrepareBulkPolynomialInterpolationAt(field, m_v[i].begin(), m_outputChannelIds[i], &(m_inputChannelIds[0]), m_w.begin(), m_threshold);
139	}
140}
141
142void RawIDA::AddOutputChannel(word32 channelId)
143{
144	m_outputChannelIds.push_back(channelId);
145	m_outputChannelIdStrings.push_back(WordToString(channelId));
146	m_outputQueues.push_back(ByteQueue());
147	if (m_inputChannelIds.size() == m_threshold)
148		ComputeV((unsigned int)m_outputChannelIds.size() - 1);
149}
150
151void RawIDA::PrepareInterpolation()
152{
153	assert(m_inputChannelIds.size() == m_threshold);
154	PrepareBulkPolynomialInterpolation(field, m_w.begin(), &(m_inputChannelIds[0]), m_threshold);
155	for (unsigned int i=0; i<m_outputChannelIds.size(); i++)
156		ComputeV(i);
157}
158
159void RawIDA::ProcessInputQueues()
160{
161	bool finished = (m_channelsFinished == m_threshold);
162	int i;
163
164	while (finished ? m_channelsReady > 0 : m_channelsReady == m_threshold)
165	{
166		m_channelsReady = 0;
167		for (i=0; i<m_threshold; i++)
168		{
169			MessageQueue &queue = m_inputQueues[i];
170			queue.GetWord32(m_y[i]);
171
172			if (finished)
173				m_channelsReady += queue.AnyRetrievable();
174			else
175				m_channelsReady += queue.NumberOfMessages() > 0 || queue.MaxRetrievable() >= 4;
176		}
177
178		for (i=0; (unsigned int)i<m_outputChannelIds.size(); i++)
179		{
180			if (m_outputToInput[i] != m_threshold)
181				m_outputQueues[i].PutWord32(m_y[m_outputToInput[i]]);
182			else if (m_v[i].size() == m_threshold)
183				m_outputQueues[i].PutWord32(BulkPolynomialInterpolateAt(field, m_y.begin(), m_v[i].begin(), m_threshold));
184			else
185			{
186				m_u.resize(m_threshold);
187				PrepareBulkPolynomialInterpolationAt(field, m_u.begin(), m_outputChannelIds[i], &(m_inputChannelIds[0]), m_w.begin(), m_threshold);
188				m_outputQueues[i].PutWord32(BulkPolynomialInterpolateAt(field, m_y.begin(), m_u.begin(), m_threshold));
189			}
190		}
191	}
192
193	if (m_outputChannelIds.size() > 0 && m_outputQueues[0].AnyRetrievable())
194		FlushOutputQueues();
195
196	if (finished)
197	{
198		OutputMessageEnds();
199
200		m_channelsReady = 0;
201		m_channelsFinished = 0;
202		m_v.clear();
203
204		vector<MessageQueue> inputQueues;
205		vector<word32> inputChannelIds;
206
207		inputQueues.swap(m_inputQueues);
208		inputChannelIds.swap(m_inputChannelIds);
209		m_inputChannelMap.clear();
210		m_lastMapPosition = m_inputChannelMap.end();
211
212		for (i=0; i<m_threshold; i++)
213		{
214			inputQueues[i].GetNextMessage();
215			inputQueues[i].TransferAllTo(*AttachedTransformation(), WordToString(inputChannelIds[i]));
216		}
217	}
218}
219
220void RawIDA::FlushOutputQueues()
221{
222	for (unsigned int i=0; i<m_outputChannelIds.size(); i++)
223		m_outputQueues[i].TransferAllTo(*AttachedTransformation(), m_outputChannelIdStrings[i]);
224}
225
226void RawIDA::OutputMessageEnds()
227{
228	if (GetAutoSignalPropagation() != 0)
229	{
230		for (unsigned int i=0; i<m_outputChannelIds.size(); i++)
231			AttachedTransformation()->ChannelMessageEnd(m_outputChannelIdStrings[i], GetAutoSignalPropagation()-1);
232	}
233}
234
235// ****************************************************************
236
237void SecretSharing::IsolatedInitialize(const NameValuePairs &parameters)
238{
239	m_pad = parameters.GetValueWithDefault("AddPadding", true);
240	m_ida.IsolatedInitialize(parameters);
241}
242
243size_t SecretSharing::Put2(const byte *begin, size_t length, int messageEnd, bool blocking)
244{
245	if (!blocking)
246		throw BlockingInputOnly("SecretSharing");
247
248	SecByteBlock buf(UnsignedMin(256, length));
249	unsigned int threshold = m_ida.GetThreshold();
250	while (length > 0)
251	{
252		size_t len = STDMIN(length, buf.size());
253		m_ida.ChannelData(0xffffffff, begin, len, false);
254		for (unsigned int i=0; i<threshold-1; i++)
255		{
256			m_rng.GenerateBlock(buf, len);
257			m_ida.ChannelData(i, buf, len, false);
258		}
259		length -= len;
260		begin += len;
261	}
262
263	if (messageEnd)
264	{
265		m_ida.SetAutoSignalPropagation(messageEnd-1);
266		if (m_pad)
267		{
268			SecretSharing::Put(1);
269			while (m_ida.InputBuffered(0xffffffff) > 0)
270				SecretSharing::Put(0);
271		}
272		m_ida.ChannelData(0xffffffff, NULL, 0, true);
273		for (unsigned int i=0; i<m_ida.GetThreshold()-1; i++)
274			m_ida.ChannelData(i, NULL, 0, true);
275	}
276
277	return 0;
278}
279
280void SecretRecovery::IsolatedInitialize(const NameValuePairs &parameters)
281{
282	m_pad = parameters.GetValueWithDefault("RemovePadding", true);
283	RawIDA::IsolatedInitialize(CombinedNameValuePairs(parameters, MakeParameters("OutputChannelID", (word32)0xffffffff)));
284}
285
286void SecretRecovery::FlushOutputQueues()
287{
288	if (m_pad)
289		m_outputQueues[0].TransferTo(*AttachedTransformation(), m_outputQueues[0].MaxRetrievable()-4);
290	else
291		m_outputQueues[0].TransferTo(*AttachedTransformation());
292}
293
294void SecretRecovery::OutputMessageEnds()
295{
296	if (m_pad)
297	{
298		PaddingRemover paddingRemover(new Redirector(*AttachedTransformation()));
299		m_outputQueues[0].TransferAllTo(paddingRemover);
300	}
301
302	if (GetAutoSignalPropagation() != 0)
303		AttachedTransformation()->MessageEnd(GetAutoSignalPropagation()-1);
304}
305
306// ****************************************************************
307
308void InformationDispersal::IsolatedInitialize(const NameValuePairs &parameters)
309{
310	m_nextChannel = 0;
311	m_pad = parameters.GetValueWithDefault("AddPadding", true);
312	m_ida.IsolatedInitialize(parameters);
313}
314
315size_t InformationDispersal::Put2(const byte *begin, size_t length, int messageEnd, bool blocking)
316{
317	if (!blocking)
318		throw BlockingInputOnly("InformationDispersal");
319
320	while (length--)
321	{
322		m_ida.ChannelData(m_nextChannel, begin, 1, false);
323		begin++;
324		m_nextChannel++;
325		if (m_nextChannel == m_ida.GetThreshold())
326			m_nextChannel = 0;
327	}
328
329	if (messageEnd)
330	{
331		m_ida.SetAutoSignalPropagation(messageEnd-1);
332		if (m_pad)
333			InformationDispersal::Put(1);
334		for (word32 i=0; i<m_ida.GetThreshold(); i++)
335			m_ida.ChannelData(i, NULL, 0, true);
336	}
337
338	return 0;
339}
340
341void InformationRecovery::IsolatedInitialize(const NameValuePairs &parameters)
342{
343	m_pad = parameters.GetValueWithDefault("RemovePadding", true);
344	RawIDA::IsolatedInitialize(parameters);
345}
346
347void InformationRecovery::FlushOutputQueues()
348{
349	while (m_outputQueues[0].AnyRetrievable())
350	{
351		for (unsigned int i=0; i<m_outputChannelIds.size(); i++)
352			m_outputQueues[i].TransferTo(m_queue, 1);
353	}
354
355	if (m_pad)
356		m_queue.TransferTo(*AttachedTransformation(), m_queue.MaxRetrievable()-4*m_threshold);
357	else
358		m_queue.TransferTo(*AttachedTransformation());
359}
360
361void InformationRecovery::OutputMessageEnds()
362{
363	if (m_pad)
364	{
365		PaddingRemover paddingRemover(new Redirector(*AttachedTransformation()));
366		m_queue.TransferAllTo(paddingRemover);
367	}
368
369	if (GetAutoSignalPropagation() != 0)
370		AttachedTransformation()->MessageEnd(GetAutoSignalPropagation()-1);
371}
372
373size_t PaddingRemover::Put2(const byte *begin, size_t length, int messageEnd, bool blocking)
374{
375	if (!blocking)
376		throw BlockingInputOnly("PaddingRemover");
377
378	const byte *const end = begin + length;
379
380	if (m_possiblePadding)
381	{
382		size_t len = find_if(begin, end, bind2nd(not_equal_to<byte>(), 0)) - begin;
383		m_zeroCount += len;
384		begin += len;
385		if (begin == end)
386			return 0;
387
388		AttachedTransformation()->Put(1);
389		while (m_zeroCount--)
390			AttachedTransformation()->Put(0);
391		AttachedTransformation()->Put(*begin++);
392		m_possiblePadding = false;
393	}
394
395#if defined(_MSC_VER) && !defined(__MWERKS__) && (_MSC_VER <= 1300)
396	// VC60 and VC7 workaround: built-in reverse_iterator has two template parameters, Dinkumware only has one
397	typedef reverse_bidirectional_iterator<const byte *, const byte> RevIt;
398#elif defined(_RWSTD_NO_CLASS_PARTIAL_SPEC)
399	typedef reverse_iterator<const byte *, random_access_iterator_tag, const byte> RevIt;
400#else
401	typedef reverse_iterator<const byte *> RevIt;
402#endif
403	const byte *x = find_if(RevIt(end), RevIt(begin), bind2nd(not_equal_to<byte>(), 0)).base();
404	if (x != begin && *(x-1) == 1)
405	{
406		AttachedTransformation()->Put(begin, x-begin-1);
407		m_possiblePadding = true;
408		m_zeroCount = end - x;
409	}
410	else
411		AttachedTransformation()->Put(begin, end-begin);
412
413	if (messageEnd)
414	{
415		m_possiblePadding = false;
416		Output(0, begin, length, messageEnd, blocking);
417	}
418	return 0;
419}
420
421NAMESPACE_END
422