1/*
2 * Copyright 2006-2009, Haiku, Inc. All Rights Reserved.
3 * Distributed under the terms of the MIT License.
4 *
5 * Authors:
 *		Axel Dörfler, axeld@pinc-software.de
7 *		Hugo Santos, hugosantos@gmail.com
8 */
9
10
11#include "EndpointManager.h"
12
13#include <new>
14#include <unistd.h>
15
16#include <KernelExport.h>
17
18#include <NetUtilities.h>
19#include <tracing.h>
20
21#include "TCPEndpoint.h"
22
23
// Define TRACE_ENDPOINT_MANAGER to get dprintf() diagnostics from this file;
// otherwise TRACE() discards its (parenthesized) argument list entirely.
//#define TRACE_ENDPOINT_MANAGER
#if defined(TRACE_ENDPOINT_MANAGER)
#	define TRACE(x) dprintf x
#else
#	define TRACE(x)
#endif

// The bind/connect/unbind trace entries below follow the TCP tracing switch.
#if TCP_TRACING
#	define ENDPOINT_TRACING
#endif
#ifdef ENDPOINT_TRACING
namespace EndpointTracing {

// Size of the address strings kept in the trace entries below. An IPv6
// address printed together with its port (for example
// "[ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff]:65535", 47 characters plus the
// terminating NUL) has to fit, or the traced address is cut off.
static const size_t kAddressStringSize = 64;


//! Trace entry recording that an endpoint was bound to an address.
class Bind : public AbstractTraceEntry {
public:
	Bind(TCPEndpoint* endpoint, ConstSocketAddress& address, bool ephemeral)
		:
		fEndpoint(endpoint),
		fEphemeral(ephemeral)
	{
		address.AsString(fAddress, sizeof(fAddress), true);
		Initialized();
	}

	Bind(TCPEndpoint* endpoint, SocketAddress& address, bool ephemeral)
		:
		fEndpoint(endpoint),
		fEphemeral(ephemeral)
	{
		address.AsString(fAddress, sizeof(fAddress), true);
		Initialized();
	}

	virtual void AddDump(TraceOutput& out)
	{
		out.Print("tcp:e:%p bind%s address %s", fEndpoint,
			fEphemeral ? " ephemeral" : "", fAddress);
	}

protected:
	TCPEndpoint*	fEndpoint;
	char			fAddress[kAddressStringSize];
	bool			fEphemeral;
};

//! Trace entry recording the (local, peer) pair an endpoint connects with.
class Connect : public AbstractTraceEntry {
public:
	Connect(TCPEndpoint* endpoint)
		:
		fEndpoint(endpoint)
	{
		endpoint->LocalAddress().AsString(fLocal, sizeof(fLocal), true);
		endpoint->PeerAddress().AsString(fPeer, sizeof(fPeer), true);
		Initialized();
	}

	virtual void AddDump(TraceOutput& out)
	{
		out.Print("tcp:e:%p connect local %s, peer %s", fEndpoint, fLocal,
			fPeer);
	}

protected:
	TCPEndpoint*	fEndpoint;
	char			fLocal[kAddressStringSize];
	char			fPeer[kAddressStringSize];
};

//! Trace entry recording the addresses an endpoint had when it was unbound.
class Unbind : public AbstractTraceEntry {
public:
	Unbind(TCPEndpoint* endpoint)
		:
		fEndpoint(endpoint)
	{
		//fStackTrace = capture_tracing_stack_trace(10, 0, false);

		endpoint->LocalAddress().AsString(fLocal, sizeof(fLocal), true);
		endpoint->PeerAddress().AsString(fPeer, sizeof(fPeer), true);
		Initialized();
	}

#if 0
	virtual void DumpStackTrace(TraceOutput& out)
	{
		out.PrintStackTrace(fStackTrace);
	}
#endif

	virtual void AddDump(TraceOutput& out)
	{
		out.Print("tcp:e:%p unbind, local %s, peer %s", fEndpoint, fLocal,
			fPeer);
	}

protected:
	TCPEndpoint*	fEndpoint;
	//tracing_stack_trace* fStackTrace;
	char			fLocal[kAddressStringSize];
	char			fPeer[kAddressStringSize];
};

}	// namespace EndpointTracing

#	define T(x)	new(std::nothrow) EndpointTracing::x
#else
#	define T(x)
#endif	// ENDPOINT_TRACING
131
132
133static const uint16 kLastReservedPort = 1023;
134static const uint16 kFirstEphemeralPort = 40000;
135
136
137ConnectionHashDefinition::ConnectionHashDefinition(EndpointManager* manager)
138	:
139	fManager(manager)
140{
141}
142
143
144size_t
145ConnectionHashDefinition::HashKey(const KeyType& key) const
146{
147	return ConstSocketAddress(fManager->AddressModule(),
148		key.first).HashPair(key.second);
149}
150
151
152size_t
153ConnectionHashDefinition::Hash(TCPEndpoint* endpoint) const
154{
155	return endpoint->LocalAddress().HashPair(*endpoint->PeerAddress());
156}
157
158
159bool
160ConnectionHashDefinition::Compare(const KeyType& key,
161	TCPEndpoint* endpoint) const
162{
163	return endpoint->LocalAddress().EqualTo(key.first, true)
164		&& endpoint->PeerAddress().EqualTo(key.second, true);
165}
166
167
168TCPEndpoint*&
169ConnectionHashDefinition::GetLink(TCPEndpoint* endpoint) const
170{
171	return endpoint->fConnectionHashLink;
172}
173
174
175//	#pragma mark -
176
177
178size_t
179EndpointHashDefinition::HashKey(uint16 port) const
180{
181	return port;
182}
183
184
185size_t
186EndpointHashDefinition::Hash(TCPEndpoint* endpoint) const
187{
188	return endpoint->LocalAddress().Port();
189}
190
191
192bool
193EndpointHashDefinition::Compare(uint16 port, TCPEndpoint* endpoint) const
194{
195	return endpoint->LocalAddress().Port() == port;
196}
197
198
199bool
200EndpointHashDefinition::CompareValues(TCPEndpoint* first,
201	TCPEndpoint* second) const
202{
203	return first->LocalAddress().Port() == second->LocalAddress().Port();
204}
205
206
207TCPEndpoint*&
208EndpointHashDefinition::GetLink(TCPEndpoint* endpoint) const
209{
210	return endpoint->fEndpointHashLink;
211}
212
213
214//	#pragma mark -
215
216
217EndpointManager::EndpointManager(net_domain* domain)
218	:
219	fDomain(domain),
220	fConnectionHash(this),
221	fLastPort(kFirstEphemeralPort)
222{
223	rw_lock_init(&fLock, "TCP endpoint manager");
224}
225
226
227EndpointManager::~EndpointManager()
228{
229	rw_lock_destroy(&fLock);
230}
231
232
233status_t
234EndpointManager::Init()
235{
236	status_t status = fConnectionHash.Init();
237	if (status == B_OK)
238		status = fEndpointHash.Init();
239
240	return status;
241}
242
243
244//	#pragma mark - connections
245
246
247/*!	Returns the endpoint matching the connection.
248	You must hold the manager's lock when calling this method (either read or
249	write).
250*/
251TCPEndpoint*
252EndpointManager::_LookupConnection(const sockaddr* local, const sockaddr* peer)
253{
254	return fConnectionHash.Lookup(std::make_pair(local, peer));
255}
256
257
258status_t
259EndpointManager::SetConnection(TCPEndpoint* endpoint, const sockaddr* _local,
260	const sockaddr* peer, const sockaddr* interfaceLocal)
261{
262	TRACE(("EndpointManager::SetConnection(%p)\n", endpoint));
263
264	WriteLocker _(fLock);
265
266	SocketAddressStorage local(AddressModule());
267	local.SetTo(_local);
268
269	if (local.IsEmpty(false)) {
270		uint16 port = local.Port();
271		local.SetTo(interfaceLocal);
272		local.SetPort(port);
273	}
274
275	// We want to create a connection for (local, peer), so check to make sure
276	// that this pair is not already in use by an existing connection.
277	if (_LookupConnection(*local, peer) != NULL)
278		return EADDRINUSE;
279
280	endpoint->LocalAddress().SetTo(*local);
281	endpoint->PeerAddress().SetTo(peer);
282	T(Connect(endpoint));
283
284	// BOpenHashTable doesn't support inserting duplicate objects. Since
285	// BOpenHashTable is a chained hash table where the items are required to
286	// be intrusive linked list nodes, inserting the same object twice will
287	// create a cycle in the linked list, which is not handled currently.
288	//
289	// We need to makes sure to remove any existing copy of this endpoint
290	// object from the table in order to handle calling connect() on a closed
291	// socket to connect to a different remote (address, port) than it was
292	// originally used for.
293	//
294	// We use RemoveUnchecked here because we don't want the hash table to
295	// resize itself after this removal when we are planning to just add
296	// another.
297	fConnectionHash.RemoveUnchecked(endpoint);
298
299	fConnectionHash.Insert(endpoint);
300	return B_OK;
301}
302
303
304status_t
305EndpointManager::SetPassive(TCPEndpoint* endpoint)
306{
307	WriteLocker _(fLock);
308
309	if (!endpoint->IsBound()) {
310		// if the socket is unbound first bind it to ephemeral
311		SocketAddressStorage local(AddressModule());
312		local.SetToEmpty();
313
314		status_t status = _BindToEphemeral(endpoint, *local);
315		if (status < B_OK)
316			return status;
317	}
318
319	SocketAddressStorage passive(AddressModule());
320	passive.SetToEmpty();
321
322	if (_LookupConnection(*endpoint->LocalAddress(), *passive))
323		return EADDRINUSE;
324
325	endpoint->PeerAddress().SetTo(*passive);
326	fConnectionHash.Insert(endpoint);
327	return B_OK;
328}
329
330
331TCPEndpoint*
332EndpointManager::FindConnection(sockaddr* local, sockaddr* peer)
333{
334	ReadLocker _(fLock);
335
336	TCPEndpoint *endpoint = _LookupConnection(local, peer);
337	if (endpoint != NULL) {
338		TRACE(("TCP: Received packet corresponds to explicit endpoint %p\n",
339			endpoint));
340		if (gSocketModule->acquire_socket(endpoint->socket))
341			return endpoint;
342	}
343
344	// no explicit endpoint exists, check for wildcard endpoints
345
346	SocketAddressStorage wildcard(AddressModule());
347	wildcard.SetToEmpty();
348
349	endpoint = _LookupConnection(local, *wildcard);
350	if (endpoint != NULL) {
351		TRACE(("TCP: Received packet corresponds to wildcard endpoint %p\n",
352			endpoint));
353		if (gSocketModule->acquire_socket(endpoint->socket))
354			return endpoint;
355	}
356
357	SocketAddressStorage localWildcard(AddressModule());
358	localWildcard.SetToEmpty();
359	localWildcard.SetPort(AddressModule()->get_port(local));
360
361	endpoint = _LookupConnection(*localWildcard, *wildcard);
362	if (endpoint != NULL) {
363		TRACE(("TCP: Received packet corresponds to local wildcard endpoint "
364			"%p\n", endpoint));
365		if (gSocketModule->acquire_socket(endpoint->socket))
366			return endpoint;
367	}
368
369	// no matching endpoint exists
370	TRACE(("TCP: no matching endpoint!\n"));
371
372	return NULL;
373}
374
375
376//	#pragma mark - endpoints
377
378
379status_t
380EndpointManager::Bind(TCPEndpoint* endpoint, const sockaddr* address)
381{
382	// check the family
383	if (!AddressModule()->is_same_family(address))
384		return EAFNOSUPPORT;
385
386	WriteLocker locker(fLock);
387
388	if (AddressModule()->get_port(address) == 0)
389		return _BindToEphemeral(endpoint, address);
390
391	return _BindToAddress(locker, endpoint, address);
392}
393
394
395status_t
396EndpointManager::BindChild(TCPEndpoint* endpoint, const sockaddr* address)
397{
398	WriteLocker _(fLock);
399	return _Bind(endpoint, address);
400}
401
402
403/*! You must have fLock write locked when calling this method. */
404status_t
405EndpointManager::_BindToAddress(WriteLocker& locker, TCPEndpoint* endpoint,
406	const sockaddr* _address)
407{
408	ConstSocketAddress address(AddressModule(), _address);
409	uint16 port = address.Port();
410
411	TRACE(("EndpointManager::BindToAddress(%p)\n", endpoint));
412	T(Bind(endpoint, address, false));
413
414	// TODO: this check follows very typical UNIX semantics
415	// and generally should be improved.
416	if (ntohs(port) <= kLastReservedPort && geteuid() != 0)
417		return B_PERMISSION_DENIED;
418
419	bool retrying = false;
420	int32 retry = 0;
421	do {
422		EndpointTable::ValueIterator portUsers = fEndpointHash.Lookup(port);
423		retry = false;
424
425		while (portUsers.HasNext()) {
426			TCPEndpoint* user = portUsers.Next();
427
428			if (user->LocalAddress().IsEmpty(false)
429				|| address.EqualTo(*user->LocalAddress(), false)) {
430				// Check if this belongs to a local connection
431
432				// Note, while we hold our lock, the endpoint cannot go away,
433				// it can only change its state - IsLocal() is safe to be used
434				// without having the endpoint locked.
435				tcp_state userState = user->State();
436				if (user->IsLocal()
437					&& (userState > ESTABLISHED || userState == CLOSED)) {
438					// This is a closing local connection - wait until it's
439					// gone away for real
440					locker.Unlock();
441					snooze(10000);
442					locker.Lock();
443						// TODO: make this better
444					if (!retrying) {
445						retrying = true;
446						retry = 5;
447					}
448					break;
449				}
450
451				if ((endpoint->socket->options & SO_REUSEADDR) == 0)
452					return EADDRINUSE;
453
454				if (userState != TIME_WAIT && userState != CLOSED)
455					return EADDRINUSE;
456			}
457		}
458	} while (retry-- > 0);
459
460	return _Bind(endpoint, *address);
461}
462
463
464/*! You must have fLock write locked when calling this method. */
465status_t
466EndpointManager::_BindToEphemeral(TCPEndpoint* endpoint,
467	const sockaddr* address)
468{
469	TRACE(("EndpointManager::BindToEphemeral(%p)\n", endpoint));
470
471	uint32 max = fLastPort + 65536;
472
473	for (int32 i = 1; i < 5; i++) {
474		// try to retrieve a more or less random port
475		uint32 step = i == 4 ? 1 : (system_time() & 0x1f) + 1;
476		uint32 counter = fLastPort + step;
477
478		while (counter < max) {
479			uint16 port = counter & 0xffff;
480			if (port <= kLastReservedPort)
481				port += kLastReservedPort;
482
483			fLastPort = port;
484			port = htons(port);
485
486			if (!fEndpointHash.Lookup(port).HasNext()) {
487				// found a port
488				SocketAddressStorage newAddress(AddressModule());
489				newAddress.SetTo(address);
490				newAddress.SetPort(port);
491
492				TRACE(("   EndpointManager::BindToEphemeral(%p) -> %s\n",
493					endpoint, AddressString(Domain(), *newAddress,
494					true).Data()));
495				T(Bind(endpoint, newAddress, true));
496
497				return _Bind(endpoint, *newAddress);
498			}
499
500			counter += step;
501		}
502	}
503
504	// could not find a port!
505	return EADDRINUSE;
506}
507
508
509status_t
510EndpointManager::_Bind(TCPEndpoint* endpoint, const sockaddr* address)
511{
512	// Thus far we have checked if the Bind() is allowed
513
514	status_t status = endpoint->next->module->bind(endpoint->next, address);
515	if (status < B_OK)
516		return status;
517
518	fEndpointHash.Insert(endpoint);
519
520	return B_OK;
521}
522
523
524status_t
525EndpointManager::Unbind(TCPEndpoint* endpoint)
526{
527	TRACE(("EndpointManager::Unbind(%p)\n", endpoint));
528	T(Unbind(endpoint));
529
530	if (endpoint == NULL || !endpoint->IsBound()) {
531		TRACE(("  endpoint is unbound.\n"));
532		return B_BAD_VALUE;
533	}
534
535	WriteLocker _(fLock);
536
537	if (!fEndpointHash.Remove(endpoint))
538		panic("bound endpoint %p not in hash!", endpoint);
539
540	fConnectionHash.Remove(endpoint);
541
542	(*endpoint->LocalAddress())->sa_len = 0;
543
544	return B_OK;
545}
546
547
548status_t
549EndpointManager::ReplyWithReset(tcp_segment_header& segment, net_buffer* buffer)
550{
551	TRACE(("TCP: Sending RST...\n"));
552
553	net_buffer* reply = gBufferModule->create(512);
554	if (reply == NULL)
555		return B_NO_MEMORY;
556
557	AddressModule()->set_to(reply->source, buffer->destination);
558	AddressModule()->set_to(reply->destination, buffer->source);
559
560	tcp_segment_header outSegment(TCP_FLAG_RESET);
561	outSegment.sequence = 0;
562	outSegment.acknowledge = 0;
563	outSegment.advertised_window = 0;
564	outSegment.urgent_offset = 0;
565
566	if ((segment.flags & TCP_FLAG_ACKNOWLEDGE) == 0) {
567		outSegment.flags |= TCP_FLAG_ACKNOWLEDGE;
568		outSegment.acknowledge = segment.sequence + buffer->size;
569		// TODO: Confirm:
570		if ((segment.flags & (TCP_FLAG_SYNCHRONIZE | TCP_FLAG_FINISH)) != 0)
571			outSegment.acknowledge++;
572	} else
573		outSegment.sequence = segment.acknowledge;
574
575	status_t status = add_tcp_header(AddressModule(), outSegment, reply);
576	if (status == B_OK)
577		status = Domain()->module->send_data(NULL, reply);
578
579	if (status != B_OK)
580		gBufferModule->free(reply);
581
582	return status;
583}
584
585
586void
587EndpointManager::Dump() const
588{
589	kprintf("-------- TCP Domain %p ---------\n", this);
590	kprintf("%10s %21s %21s %8s %8s %12s\n", "address", "local", "peer",
591		"recv-q", "send-q", "state");
592
593	ConnectionTable::Iterator iterator = fConnectionHash.GetIterator();
594
595	while (iterator.HasNext()) {
596		TCPEndpoint *endpoint = iterator.Next();
597
598		char localBuf[64], peerBuf[64];
599		endpoint->LocalAddress().AsString(localBuf, sizeof(localBuf), true);
600		endpoint->PeerAddress().AsString(peerBuf, sizeof(peerBuf), true);
601
602		kprintf("%p %21s %21s %8lu %8lu %12s\n", endpoint, localBuf, peerBuf,
603			endpoint->fReceiveQueue.Available(), endpoint->fSendQueue.Used(),
604			name_for_state(endpoint->State()));
605	}
606}
607
608