1// network.cpp - written and placed in the public domain by Wei Dai
2
3#include "pch.h"
4#include "network.h"
5#include "wait.h"
6
7#define CRYPTOPP_TRACE_NETWORK 0
8
9NAMESPACE_BEGIN(CryptoPP)
10
11#ifdef HIGHRES_TIMER_AVAILABLE
12
13lword LimitedBandwidth::ComputeCurrentTransceiveLimit()
14{
15	if (!m_maxBytesPerSecond)
16		return ULONG_MAX;
17
18	double curTime = GetCurTimeAndCleanUp();
19	lword total = 0;
20	for (OpQueue::size_type i=0; i!=m_ops.size(); ++i)
21		total += m_ops[i].second;
22	return SaturatingSubtract(m_maxBytesPerSecond, total);
23}
24
25double LimitedBandwidth::TimeToNextTransceive()
26{
27	if (!m_maxBytesPerSecond)
28		return 0;
29
30	if (!m_nextTransceiveTime)
31		ComputeNextTransceiveTime();
32
33	return SaturatingSubtract(m_nextTransceiveTime, m_timer.ElapsedTimeAsDouble());
34}
35
36void LimitedBandwidth::NoteTransceive(lword size)
37{
38	if (m_maxBytesPerSecond)
39	{
40		double curTime = GetCurTimeAndCleanUp();
41		m_ops.push_back(std::make_pair(curTime, size));
42		m_nextTransceiveTime = 0;
43	}
44}
45
46void LimitedBandwidth::ComputeNextTransceiveTime()
47{
48	double curTime = GetCurTimeAndCleanUp();
49	lword total = 0;
50	for (unsigned int i=0; i!=m_ops.size(); ++i)
51		total += m_ops[i].second;
52	m_nextTransceiveTime =
53		(total < m_maxBytesPerSecond) ? curTime : m_ops.front().first + 1000;
54}
55
56double LimitedBandwidth::GetCurTimeAndCleanUp()
57{
58	if (!m_maxBytesPerSecond)
59		return 0;
60
61	double curTime = m_timer.ElapsedTimeAsDouble();
62	while (m_ops.size() && (m_ops.front().first + 1000 < curTime))
63		m_ops.pop_front();
64	return curTime;
65}
66
67void LimitedBandwidth::GetWaitObjects(WaitObjectContainer &container, const CallStack &callStack)
68{
69	double nextTransceiveTime = TimeToNextTransceive();
70	if (nextTransceiveTime)
71		container.ScheduleEvent(nextTransceiveTime, CallStack("LimitedBandwidth::GetWaitObjects()", &callStack));
72}
73
74// *************************************************************
75
76size_t NonblockingSource::GeneralPump2(
77	lword& byteCount, bool blockingOutput,
78	unsigned long maxTime, bool checkDelimiter, byte delimiter)
79{
80	m_blockedBySpeedLimit = false;
81
82	if (!GetMaxBytesPerSecond())
83	{
84		size_t ret = DoPump(byteCount, blockingOutput, maxTime, checkDelimiter, delimiter);
85		m_doPumpBlocked = (ret != 0);
86		return ret;
87	}
88
89	bool forever = (maxTime == INFINITE_TIME);
90	unsigned long timeToGo = maxTime;
91	Timer timer(Timer::MILLISECONDS, forever);
92	lword maxSize = byteCount;
93	byteCount = 0;
94
95	timer.StartTimer();
96
97	while (true)
98	{
99		lword curMaxSize = UnsignedMin(ComputeCurrentTransceiveLimit(), maxSize - byteCount);
100
101		if (curMaxSize || m_doPumpBlocked)
102		{
103			if (!forever) timeToGo = SaturatingSubtract(maxTime, timer.ElapsedTime());
104			size_t ret = DoPump(curMaxSize, blockingOutput, timeToGo, checkDelimiter, delimiter);
105			m_doPumpBlocked = (ret != 0);
106			if (curMaxSize)
107			{
108				NoteTransceive(curMaxSize);
109				byteCount += curMaxSize;
110			}
111			if (ret)
112				return ret;
113		}
114
115		if (maxSize != ULONG_MAX && byteCount >= maxSize)
116			break;
117
118		if (!forever)
119		{
120			timeToGo = SaturatingSubtract(maxTime, timer.ElapsedTime());
121			if (!timeToGo)
122				break;
123		}
124
125		double waitTime = TimeToNextTransceive();
126		if (!forever && waitTime > timeToGo)
127		{
128			m_blockedBySpeedLimit = true;
129			break;
130		}
131
132		WaitObjectContainer container;
133		LimitedBandwidth::GetWaitObjects(container, CallStack("NonblockingSource::GeneralPump2() - speed limit", 0));
134		container.Wait((unsigned long)waitTime);
135	}
136
137	return 0;
138}
139
140size_t NonblockingSource::PumpMessages2(unsigned int &messageCount, bool blocking)
141{
142	if (messageCount == 0)
143		return 0;
144
145	messageCount = 0;
146
147	lword byteCount;
148	do {
149		byteCount = LWORD_MAX;
150		RETURN_IF_NONZERO(Pump2(byteCount, blocking));
151	} while(byteCount == LWORD_MAX);
152
153	if (!m_messageEndSent && SourceExhausted())
154	{
155		RETURN_IF_NONZERO(AttachedTransformation()->Put2(NULL, 0, GetAutoSignalPropagation(), true));
156		m_messageEndSent = true;
157		messageCount = 1;
158	}
159	return 0;
160}
161
162lword NonblockingSink::TimedFlush(unsigned long maxTime, size_t targetSize)
163{
164	m_blockedBySpeedLimit = false;
165
166	size_t curBufSize = GetCurrentBufferSize();
167	if (curBufSize <= targetSize && (targetSize || !EofPending()))
168		return 0;
169
170	if (!GetMaxBytesPerSecond())
171		return DoFlush(maxTime, targetSize);
172
173	bool forever = (maxTime == INFINITE_TIME);
174	unsigned long timeToGo = maxTime;
175	Timer timer(Timer::MILLISECONDS, forever);
176	lword totalFlushed = 0;
177
178	timer.StartTimer();
179
180	while (true)
181	{
182		size_t flushSize = UnsignedMin(curBufSize - targetSize, ComputeCurrentTransceiveLimit());
183		if (flushSize || EofPending())
184		{
185			if (!forever) timeToGo = SaturatingSubtract(maxTime, timer.ElapsedTime());
186			size_t ret = (size_t)DoFlush(timeToGo, curBufSize - flushSize);
187			if (ret)
188			{
189				NoteTransceive(ret);
190				curBufSize -= ret;
191				totalFlushed += ret;
192			}
193		}
194
195		if (curBufSize <= targetSize && (targetSize || !EofPending()))
196			break;
197
198		if (!forever)
199		{
200			timeToGo = SaturatingSubtract(maxTime, timer.ElapsedTime());
201			if (!timeToGo)
202				break;
203		}
204
205		double waitTime = TimeToNextTransceive();
206		if (!forever && waitTime > timeToGo)
207		{
208			m_blockedBySpeedLimit = true;
209			break;
210		}
211
212		WaitObjectContainer container;
213		LimitedBandwidth::GetWaitObjects(container, CallStack("NonblockingSink::TimedFlush() - speed limit", 0));
214		container.Wait((unsigned long)waitTime);
215	}
216
217	return totalFlushed;
218}
219
220bool NonblockingSink::IsolatedFlush(bool hardFlush, bool blocking)
221{
222	TimedFlush(blocking ? INFINITE_TIME : 0);
223	return hardFlush && (!!GetCurrentBufferSize() || EofPending());
224}
225
226// *************************************************************
227
228NetworkSource::NetworkSource(BufferedTransformation *attachment)
229	: NonblockingSource(attachment), m_buf(1024*16)
230	, m_waitingForResult(false), m_outputBlocked(false)
231	, m_dataBegin(0), m_dataEnd(0)
232{
233}
234
235unsigned int NetworkSource::GetMaxWaitObjectCount() const
236{
237	return LimitedBandwidth::GetMaxWaitObjectCount()
238		+ GetReceiver().GetMaxWaitObjectCount()
239		+ AttachedTransformation()->GetMaxWaitObjectCount();
240}
241
242void NetworkSource::GetWaitObjects(WaitObjectContainer &container, CallStack const& callStack)
243{
244	if (BlockedBySpeedLimit())
245		LimitedBandwidth::GetWaitObjects(container, CallStack("NetworkSource::GetWaitObjects() - speed limit", &callStack));
246	else if (!m_outputBlocked)
247	{
248		if (m_dataBegin == m_dataEnd)
249			AccessReceiver().GetWaitObjects(container, CallStack("NetworkSource::GetWaitObjects() - no data", &callStack));
250		else
251			container.SetNoWait(CallStack("NetworkSource::GetWaitObjects() - have data", &callStack));
252	}
253
254	AttachedTransformation()->GetWaitObjects(container, CallStack("NetworkSource::GetWaitObjects() - attachment", &callStack));
255}
256
// Moves data from the receiver, via m_buf, into the attached transformation.
//
// byteCount      in: maximum number of bytes to pass on; out: number of bytes
//                actually passed to the attachment by this call
// blockingOutput passed to PutModifiable2() as the blocking flag (output is also
//                blocking whenever maxTime is INFINITE_TIME)
// maxTime        time limit, measured with a Timer in MILLISECONDS mode;
//                INFINITE_TIME means no limit
// checkDelimiter when set, output stops in front of the first `delimiter` byte
// delimiter      the delimiter byte used when checkDelimiter is set
//
// Returns 0 normally. Returns the nonzero result of PutModifiable2() when the
// attachment blocked and waiting for it did not succeed; in that case
// m_outputBlocked is set so that the next call resumes at DoOutput.
size_t NetworkSource::DoPump(lword &byteCount, bool blockingOutput, unsigned long maxTime, bool checkDelimiter, byte delimiter)
{
	NetworkReceiver &receiver = AccessReceiver();

	lword maxSize = byteCount;
	byteCount = 0;
	bool forever = maxTime == INFINITE_TIME;
	Timer timer(Timer::MILLISECONDS, forever);
	BufferedTransformation *t = AttachedTransformation();

	// a previous call returned with the attachment blocked: retry that same
	// output first, using the m_dataBegin/m_putSize values left from that call
	if (m_outputBlocked)
		goto DoOutput;

	while (true)
	{
		// buffer is empty: try to obtain more data from the receiver
		if (m_dataBegin == m_dataEnd)
		{
			if (receiver.EofReceived())
				break;

			if (m_waitingForResult)
			{
				// a Receive() was started earlier; collect its result, waiting
				// for it first if the receiver requires that
				if (receiver.MustWaitForResult() &&
					!receiver.Wait(SaturatingSubtract(maxTime, timer.ElapsedTime()),
						CallStack("NetworkSource::DoPump() - wait receive result", 0)))
					break;

				unsigned int recvResult = receiver.GetReceiveResult();
#if CRYPTOPP_TRACE_NETWORK
				OutputDebugString((IntToString((unsigned int)this) + ": Received " + IntToString(recvResult) + " bytes\n").c_str());
#endif
				m_dataEnd += recvResult;
				m_waitingForResult = false;

				// more can be received without waiting and the buffer still has
				// room: keep receiving right away
				if (!receiver.MustWaitToReceive() && !receiver.EofReceived() && m_dataEnd != m_buf.size())
					goto ReceiveNoWait;
			}
			else
			{
				// everything buffered has been consumed, so start again at the
				// beginning of the buffer
				m_dataEnd = m_dataBegin = 0;

				if (receiver.MustWaitToReceive())
				{
					if (!receiver.Wait(SaturatingSubtract(maxTime, timer.ElapsedTime()),
							CallStack("NetworkSource::DoPump() - wait receive", 0)))
						break;

					// start a receive; its result is collected on a later pass
					// through the m_waitingForResult branch above
					receiver.Receive(m_buf+m_dataEnd, m_buf.size()-m_dataEnd);
					m_waitingForResult = true;
				}
				else
				{
ReceiveNoWait:
					m_waitingForResult = true;
					// call Receive repeatedly as long as data is immediately available,
					// because some receivers tend to return data in small pieces
#if CRYPTOPP_TRACE_NETWORK
					OutputDebugString((IntToString((unsigned int)this) + ": Receiving " + IntToString(m_buf.size()-m_dataEnd) + " bytes\n").c_str());
#endif
					// if Receive() returns false the loop ends with
					// m_waitingForResult still true, so the result is picked up later
					while (receiver.Receive(m_buf+m_dataEnd, m_buf.size()-m_dataEnd))
					{
						unsigned int recvResult = receiver.GetReceiveResult();
#if CRYPTOPP_TRACE_NETWORK
						OutputDebugString((IntToString((unsigned int)this) + ": Received " + IntToString(recvResult) + " bytes\n").c_str());
#endif
						m_dataEnd += recvResult;
						// stop at EOF or once the buffer is more than half full
						if (receiver.EofReceived() || m_dataEnd > m_buf.size() /2)
						{
							m_waitingForResult = false;
							break;
						}
					}
				}
			}
		}
		else
		{
			// buffer holds data: pass on as much as the caller still wants
			m_putSize = UnsignedMin(m_dataEnd - m_dataBegin, maxSize - byteCount);

			// with delimiter checking, shorten the output so that it ends just
			// before the first delimiter byte (if one is present in the range)
			if (checkDelimiter)
				m_putSize = std::find(m_buf+m_dataBegin, m_buf+m_dataBegin+m_putSize, delimiter) - (m_buf+m_dataBegin);

DoOutput:
			size_t result = t->PutModifiable2(m_buf+m_dataBegin, m_putSize, 0, forever || blockingOutput);
			if (result)
			{
				// attachment blocked: wait for it within the remaining time and
				// retry, or report the blockage to the caller
				if (t->Wait(SaturatingSubtract(maxTime, timer.ElapsedTime()),
						CallStack("NetworkSource::DoPump() - wait attachment", 0)))
					goto DoOutput;
				else
				{
					m_outputBlocked = true;
					return result;
				}
			}
			m_outputBlocked = false;

			byteCount += m_putSize;
			m_dataBegin += m_putSize;
			// stopped in front of a delimiter: leave it in the buffer and return
			if (checkDelimiter && m_dataBegin < m_dataEnd && m_buf[m_dataBegin] == delimiter)
				break;
			if (maxSize != ULONG_MAX && byteCount == maxSize)
				break;
			// once time limit is reached, return even if there is more data waiting
			// but make 0 a special case so caller can request a large amount of data to be
			// pumped as long as it is immediately available
			if (maxTime > 0 && timer.ElapsedTime() > maxTime)
				break;
		}
	}

	return 0;
}
370
371// *************************************************************
372
373NetworkSink::NetworkSink(unsigned int maxBufferSize, unsigned int autoFlushBound)
374	: m_maxBufferSize(maxBufferSize), m_autoFlushBound(autoFlushBound)
375	, m_needSendResult(false), m_wasBlocked(false), m_eofState(EOF_NONE)
376	, m_buffer(STDMIN(16U*1024U+256, maxBufferSize)), m_skipBytes(0)
377	, m_speedTimer(Timer::MILLISECONDS), m_byteCountSinceLastTimerReset(0)
378	, m_currentSpeed(0), m_maxObservedSpeed(0)
379{
380}
381
382float NetworkSink::ComputeCurrentSpeed()
383{
384	if (m_speedTimer.ElapsedTime() > 1000)
385	{
386		m_currentSpeed = m_byteCountSinceLastTimerReset * 1000 / m_speedTimer.ElapsedTime();
387		m_maxObservedSpeed = STDMAX(m_currentSpeed, m_maxObservedSpeed * 0.98f);
388		m_byteCountSinceLastTimerReset = 0;
389		m_speedTimer.StartTimer();
390//		OutputDebugString(("max speed: " + IntToString((int)m_maxObservedSpeed) + " current speed: " + IntToString((int)m_currentSpeed) + "\n").c_str());
391	}
392	return m_currentSpeed;
393}
394
395float NetworkSink::GetMaxObservedSpeed() const
396{
397	lword m = GetMaxBytesPerSecond();
398	return m ? STDMIN(m_maxObservedSpeed, float(CRYPTOPP_VC6_INT64 m)) : m_maxObservedSpeed;
399}
400
401unsigned int NetworkSink::GetMaxWaitObjectCount() const
402{
403	return LimitedBandwidth::GetMaxWaitObjectCount() + GetSender().GetMaxWaitObjectCount();
404}
405
406void NetworkSink::GetWaitObjects(WaitObjectContainer &container, CallStack const& callStack)
407{
408	if (BlockedBySpeedLimit())
409		LimitedBandwidth::GetWaitObjects(container, CallStack("NetworkSink::GetWaitObjects() - speed limit", &callStack));
410	else if (m_wasBlocked)
411		AccessSender().GetWaitObjects(container, CallStack("NetworkSink::GetWaitObjects() - was blocked", &callStack));
412	else if (!m_buffer.IsEmpty())
413		AccessSender().GetWaitObjects(container, CallStack("NetworkSink::GetWaitObjects() - buffer not empty", &callStack));
414	else if (EofPending())
415		AccessSender().GetWaitObjects(container, CallStack("NetworkSink::GetWaitObjects() - EOF pending", &callStack));
416}
417
// Buffers `length` bytes from inString and flushes them towards the sender.
//
// inString   data to send
// length     number of bytes in inString
// messageEnd nonzero to request that an EOF be sent after the data
// blocking   when true, wait until the buffer has drained to the target size;
//            when false, flush only what can be sent without waiting
//
// Returns 0 on success. Returns a nonzero value when the call could not
// complete: either the buffer still holds more than the target size (the value
// is then the number of blocked bytes, at least 1), or the EOF has not been
// completed yet (the value is then 1).
// Throws Exception(OTHER_ERROR) if data or a message end is supplied after the
// EOF has already been sent.
size_t NetworkSink::Put2(const byte *inString, size_t length, int messageEnd, bool blocking)
{
	if (m_eofState == EOF_DONE)
	{
		if (length || messageEnd)
			throw Exception(Exception::OTHER_ERROR, "NetworkSink::Put2() being called after EOF had been sent");

		return 0;
	}

	// an EOF is already in progress from an earlier call: just keep working on it
	if (m_eofState > EOF_NONE)
		goto EofSite;

	{
		// m_skipBytes counts the bytes buffered by an earlier call that returned
		// blocked; the assert expects the same data to be presented again, and
		// that already-buffered prefix is skipped here
		if (m_skipBytes)
		{
			assert(length >= m_skipBytes);
			inString += m_skipBytes;
			length -= m_skipBytes;
		}

		m_buffer.Put(inString, length);

		// opportunistic flush without waiting
		if (!blocking || m_buffer.CurrentSize() > m_autoFlushBound)
			TimedFlush(0, 0);

		// a message end requires the buffer to be emptied completely; otherwise
		// up to m_maxBufferSize bytes may stay buffered
		size_t targetSize = messageEnd ? 0 : m_maxBufferSize;
		if (blocking)
			TimedFlush(INFINITE_TIME, targetSize);

		if (m_buffer.CurrentSize() > targetSize)
		{
			assert(!blocking);
			m_wasBlocked = true;
			// remember how much of the input is now in the buffer
			m_skipBytes += length;
			size_t blockedBytes = UnsignedMin(length, m_buffer.CurrentSize() - targetSize);
			// never return 0 here, since 0 would mean success
			return STDMAX<size_t>(blockedBytes, 1);
		}

		m_wasBlocked = false;
		m_skipBytes = 0;
	}

	if (messageEnd)
	{
		m_eofState = EOF_PENDING_SEND;

	EofSite:
		// flushing advances m_eofState (see DoFlush()); report "blocked" until
		// it reaches EOF_DONE
		TimedFlush(blocking ? INFINITE_TIME : 0, 0);
		if (m_eofState != EOF_DONE)
			return 1;
	}

	return 0;
}
473
// Sends buffered data through the sender until at most targetSize bytes remain
// in m_buffer, a wait on the sender fails, or the time limit is reached.
// Afterwards updates the speed statistics and, once the buffer is empty and no
// send result is outstanding, advances a pending EOF.
//
// maxTime    time limit, measured with a Timer in MILLISECONDS mode;
//            INFINITE_TIME means no limit, 0 means do not wait
// targetSize stop sending once the buffer holds no more than this many bytes
//
// Returns the number of bytes whose send result was collected during this call.
lword NetworkSink::DoFlush(unsigned long maxTime, size_t targetSize)
{
	NetworkSender &sender = AccessSender();

	bool forever = maxTime == INFINITE_TIME;
	Timer timer(Timer::MILLISECONDS, forever);
	unsigned int totalFlushSize = 0;

	while (true)
	{
		if (m_buffer.CurrentSize() <= targetSize)
			break;

		// a Send() was started earlier: collect its result before sending more
		if (m_needSendResult)
		{
			if (sender.MustWaitForResult() &&
				!sender.Wait(SaturatingSubtract(maxTime, timer.ElapsedTime()),
					CallStack("NetworkSink::DoFlush() - wait send result", 0)))
				break;

			unsigned int sendResult = sender.GetSendResult();
#if CRYPTOPP_TRACE_NETWORK
			OutputDebugString((IntToString((unsigned int)this) + ": Sent " + IntToString(sendResult) + " bytes\n").c_str());
#endif
			// drop the bytes that were sent from the buffer
			m_buffer.Skip(sendResult);
			totalFlushSize += sendResult;
			m_needSendResult = false;

			if (!m_buffer.AnyRetrievable())
				break;
		}

		// with maxTime == 0 the wait below is a pure poll
		unsigned long timeOut = maxTime ? SaturatingSubtract(maxTime, timer.ElapsedTime()) : 0;
		if (sender.MustWaitToSend() && !sender.Wait(timeOut, CallStack("NetworkSink::DoFlush() - wait send", 0)))
			break;

		// peek at the next contiguous block in the buffer; it is removed only
		// when the send result is collected above
		size_t contiguousSize = 0;
		const byte *block = m_buffer.Spy(contiguousSize);

#if CRYPTOPP_TRACE_NETWORK
		OutputDebugString((IntToString((unsigned int)this) + ": Sending " + IntToString(contiguousSize) + " bytes\n").c_str());
#endif
		sender.Send(block, contiguousSize);
		m_needSendResult = true;

		if (maxTime > 0 && timeOut == 0)
			break;	// once time limit is reached, return even if there is more data waiting
	}

	// feed the bytes sent in this call into the speed measurement
	m_byteCountSinceLastTimerReset += totalFlushSize;
	ComputeCurrentSpeed();

	// all data is out and acknowledged: handle a pending EOF
	if (m_buffer.IsEmpty() && !m_needSendResult)
	{
		if (m_eofState == EOF_PENDING_SEND)
		{
			sender.SendEof();
			m_eofState = sender.MustWaitForEof() ? EOF_PENDING_DELIVERY : EOF_DONE;
		}

		// wait, within the remaining time, until the sender reports the EOF as sent
		while (m_eofState == EOF_PENDING_DELIVERY)
		{
			unsigned long timeOut = maxTime ? SaturatingSubtract(maxTime, timer.ElapsedTime()) : 0;
			if (!sender.Wait(timeOut, CallStack("NetworkSink::DoFlush() - wait EOF", 0)))
				break;

			if (sender.EofSent())
				m_eofState = EOF_DONE;
		}
	}

	return totalFlushSize;
}
547
548#endif	// #ifdef HIGHRES_TIMER_AVAILABLE
549
550NAMESPACE_END
551