1/*
2 * Copyright 2006-2009, Haiku, Inc. All Rights Reserved.
3 * Distributed under the terms of the MIT License.
4 *
5 * Authors:
6 *		Axel Dörfler, axeld@pinc-software.de
7 *		Hugo Santos, hugosantos@gmail.com
8 */
9
10
11#include "EndpointManager.h"
12
13#include <new>
14#include <unistd.h>
15
16#include <KernelExport.h>
17
18#include <NetUtilities.h>
19#include <tracing.h>
20
21#include "TCPEndpoint.h"
22
23
24//#define TRACE_ENDPOINT_MANAGER
25#ifdef TRACE_ENDPOINT_MANAGER
26#	define TRACE(x) dprintf x
27#else
28#	define TRACE(x)
29#endif
30
31#if TCP_TRACING
32#	define ENDPOINT_TRACING
33#endif
34#ifdef ENDPOINT_TRACING
35namespace EndpointTracing {
36
37class Bind : public AbstractTraceEntry {
38public:
39	Bind(TCPEndpoint* endpoint, ConstSocketAddress& address, bool ephemeral)
40		:
41		fEndpoint(endpoint),
42		fEphemeral(ephemeral)
43	{
44		address.AsString(fAddress, sizeof(fAddress), true);
45		Initialized();
46	}
47
48	Bind(TCPEndpoint* endpoint, SocketAddress& address, bool ephemeral)
49		:
50		fEndpoint(endpoint),
51		fEphemeral(ephemeral)
52	{
53		address.AsString(fAddress, sizeof(fAddress), true);
54		Initialized();
55	}
56
57	virtual void AddDump(TraceOutput& out)
58	{
59		out.Print("tcp:e:%p bind%s address %s", fEndpoint,
60			fEphemeral ? " ephemeral" : "", fAddress);
61	}
62
63protected:
64	TCPEndpoint*	fEndpoint;
65	char			fAddress[32];
66	bool			fEphemeral;
67};
68
69class Connect : public AbstractTraceEntry {
70public:
71	Connect(TCPEndpoint* endpoint)
72		:
73		fEndpoint(endpoint)
74	{
75		endpoint->LocalAddress().AsString(fLocal, sizeof(fLocal), true);
76		endpoint->PeerAddress().AsString(fPeer, sizeof(fPeer), true);
77		Initialized();
78	}
79
80	virtual void AddDump(TraceOutput& out)
81	{
82		out.Print("tcp:e:%p connect local %s, peer %s", fEndpoint, fLocal,
83			fPeer);
84	}
85
86protected:
87	TCPEndpoint*	fEndpoint;
88	char			fLocal[32];
89	char			fPeer[32];
90};
91
92class Unbind : public AbstractTraceEntry {
93public:
94	Unbind(TCPEndpoint* endpoint)
95		:
96		fEndpoint(endpoint)
97	{
98		//fStackTrace = capture_tracing_stack_trace(10, 0, false);
99
100		endpoint->LocalAddress().AsString(fLocal, sizeof(fLocal), true);
101		endpoint->PeerAddress().AsString(fPeer, sizeof(fPeer), true);
102		Initialized();
103	}
104
105#if 0
106	virtual void DumpStackTrace(TraceOutput& out)
107	{
108		out.PrintStackTrace(fStackTrace);
109	}
110#endif
111
112	virtual void AddDump(TraceOutput& out)
113	{
114		out.Print("tcp:e:%p unbind, local %s, peer %s", fEndpoint, fLocal,
115			fPeer);
116	}
117
118protected:
119	TCPEndpoint*	fEndpoint;
120	//tracing_stack_trace* fStackTrace;
121	char			fLocal[32];
122	char			fPeer[32];
123};
124
125}	// namespace EndpointTracing
126
127#	define T(x)	new(std::nothrow) EndpointTracing::x
128#else
129#	define T(x)
130#endif	// ENDPOINT_TRACING
131
132
133static const uint16 kLastReservedPort = 1023;
134static const uint16 kFirstEphemeralPort = 40000;
135
136
137ConnectionHashDefinition::ConnectionHashDefinition(EndpointManager* manager)
138	:
139	fManager(manager)
140{
141}
142
143
144size_t
145ConnectionHashDefinition::HashKey(const KeyType& key) const
146{
147	return ConstSocketAddress(fManager->AddressModule(),
148		key.first).HashPair(key.second);
149}
150
151
152size_t
153ConnectionHashDefinition::Hash(TCPEndpoint* endpoint) const
154{
155	return endpoint->LocalAddress().HashPair(*endpoint->PeerAddress());
156}
157
158
159bool
160ConnectionHashDefinition::Compare(const KeyType& key,
161	TCPEndpoint* endpoint) const
162{
163	return endpoint->LocalAddress().EqualTo(key.first, true)
164		&& endpoint->PeerAddress().EqualTo(key.second, true);
165}
166
167
168TCPEndpoint*&
169ConnectionHashDefinition::GetLink(TCPEndpoint* endpoint) const
170{
171	return endpoint->fConnectionHashLink;
172}
173
174
175//	#pragma mark -
176
177
178size_t
179EndpointHashDefinition::HashKey(uint16 port) const
180{
181	return port;
182}
183
184
185size_t
186EndpointHashDefinition::Hash(TCPEndpoint* endpoint) const
187{
188	return endpoint->LocalAddress().Port();
189}
190
191
192bool
193EndpointHashDefinition::Compare(uint16 port, TCPEndpoint* endpoint) const
194{
195	return endpoint->LocalAddress().Port() == port;
196}
197
198
199bool
200EndpointHashDefinition::CompareValues(TCPEndpoint* first,
201	TCPEndpoint* second) const
202{
203	return first->LocalAddress().Port() == second->LocalAddress().Port();
204}
205
206
207TCPEndpoint*&
208EndpointHashDefinition::GetLink(TCPEndpoint* endpoint) const
209{
210	return endpoint->fEndpointHashLink;
211}
212
213
214//	#pragma mark -
215
216
217EndpointManager::EndpointManager(net_domain* domain)
218	:
219	fDomain(domain),
220	fConnectionHash(this),
221	fLastPort(kFirstEphemeralPort)
222{
223	rw_lock_init(&fLock, "TCP endpoint manager");
224}
225
226
227EndpointManager::~EndpointManager()
228{
229	rw_lock_destroy(&fLock);
230}
231
232
233status_t
234EndpointManager::Init()
235{
236	status_t status = fConnectionHash.Init();
237	if (status == B_OK)
238		status = fEndpointHash.Init();
239
240	return status;
241}
242
243
244//	#pragma mark - connections
245
246
247/*!	Returns the endpoint matching the connection.
248	You must hold the manager's lock when calling this method (either read or
249	write).
250*/
251TCPEndpoint*
252EndpointManager::_LookupConnection(const sockaddr* local, const sockaddr* peer)
253{
254	return fConnectionHash.Lookup(std::make_pair(local, peer));
255}
256
257
258status_t
259EndpointManager::SetConnection(TCPEndpoint* endpoint, const sockaddr* _local,
260	const sockaddr* peer, const sockaddr* interfaceLocal)
261{
262	TRACE(("EndpointManager::SetConnection(%p)\n", endpoint));
263
264	WriteLocker _(fLock);
265
266	SocketAddressStorage local(AddressModule());
267	local.SetTo(_local);
268
269	if (local.IsEmpty(false)) {
270		uint16 port = local.Port();
271		local.SetTo(interfaceLocal);
272		local.SetPort(port);
273	}
274
275	if (_LookupConnection(*local, peer) != NULL)
276		return EADDRINUSE;
277
278	endpoint->LocalAddress().SetTo(*local);
279	endpoint->PeerAddress().SetTo(peer);
280	T(Connect(endpoint));
281
282	fConnectionHash.Insert(endpoint);
283	return B_OK;
284}
285
286
287status_t
288EndpointManager::SetPassive(TCPEndpoint* endpoint)
289{
290	WriteLocker _(fLock);
291
292	if (!endpoint->IsBound()) {
293		// if the socket is unbound first bind it to ephemeral
294		SocketAddressStorage local(AddressModule());
295		local.SetToEmpty();
296
297		status_t status = _BindToEphemeral(endpoint, *local);
298		if (status < B_OK)
299			return status;
300	}
301
302	SocketAddressStorage passive(AddressModule());
303	passive.SetToEmpty();
304
305	if (_LookupConnection(*endpoint->LocalAddress(), *passive))
306		return EADDRINUSE;
307
308	endpoint->PeerAddress().SetTo(*passive);
309	fConnectionHash.Insert(endpoint);
310	return B_OK;
311}
312
313
314TCPEndpoint*
315EndpointManager::FindConnection(sockaddr* local, sockaddr* peer)
316{
317	ReadLocker _(fLock);
318
319	TCPEndpoint *endpoint = _LookupConnection(local, peer);
320	if (endpoint != NULL) {
321		TRACE(("TCP: Received packet corresponds to explicit endpoint %p\n",
322			endpoint));
323		if (gSocketModule->acquire_socket(endpoint->socket))
324			return endpoint;
325	}
326
327	// no explicit endpoint exists, check for wildcard endpoints
328
329	SocketAddressStorage wildcard(AddressModule());
330	wildcard.SetToEmpty();
331
332	endpoint = _LookupConnection(local, *wildcard);
333	if (endpoint != NULL) {
334		TRACE(("TCP: Received packet corresponds to wildcard endpoint %p\n",
335			endpoint));
336		if (gSocketModule->acquire_socket(endpoint->socket))
337			return endpoint;
338	}
339
340	SocketAddressStorage localWildcard(AddressModule());
341	localWildcard.SetToEmpty();
342	localWildcard.SetPort(AddressModule()->get_port(local));
343
344	endpoint = _LookupConnection(*localWildcard, *wildcard);
345	if (endpoint != NULL) {
346		TRACE(("TCP: Received packet corresponds to local wildcard endpoint "
347			"%p\n", endpoint));
348		if (gSocketModule->acquire_socket(endpoint->socket))
349			return endpoint;
350	}
351
352	// no matching endpoint exists
353	TRACE(("TCP: no matching endpoint!\n"));
354
355	return NULL;
356}
357
358
359//	#pragma mark - endpoints
360
361
362status_t
363EndpointManager::Bind(TCPEndpoint* endpoint, const sockaddr* address)
364{
365	// check the family
366	if (!AddressModule()->is_same_family(address))
367		return EAFNOSUPPORT;
368
369	WriteLocker locker(fLock);
370
371	if (AddressModule()->get_port(address) == 0)
372		return _BindToEphemeral(endpoint, address);
373
374	return _BindToAddress(locker, endpoint, address);
375}
376
377
378status_t
379EndpointManager::BindChild(TCPEndpoint* endpoint)
380{
381	WriteLocker _(fLock);
382	return _Bind(endpoint, *endpoint->LocalAddress());
383}
384
385
386/*! You must have fLock write locked when calling this method. */
387status_t
388EndpointManager::_BindToAddress(WriteLocker& locker, TCPEndpoint* endpoint,
389	const sockaddr* _address)
390{
391	ConstSocketAddress address(AddressModule(), _address);
392	uint16 port = address.Port();
393
394	TRACE(("EndpointManager::BindToAddress(%p)\n", endpoint));
395	T(Bind(endpoint, address, false));
396
397	// TODO: this check follows very typical UNIX semantics
398	// and generally should be improved.
399	if (ntohs(port) <= kLastReservedPort && geteuid() != 0)
400		return B_PERMISSION_DENIED;
401
402	bool retrying = false;
403	int32 retry = 0;
404	do {
405		EndpointTable::ValueIterator portUsers = fEndpointHash.Lookup(port);
406		retry = false;
407
408		while (portUsers.HasNext()) {
409			TCPEndpoint* user = portUsers.Next();
410
411			if (user->LocalAddress().IsEmpty(false)
412				|| address.EqualTo(*user->LocalAddress(), false)) {
413				// Check if this belongs to a local connection
414
415				// Note, while we hold our lock, the endpoint cannot go away,
416				// it can only change its state - IsLocal() is safe to be used
417				// without having the endpoint locked.
418				tcp_state userState = user->State();
419				if (user->IsLocal()
420					&& (userState > ESTABLISHED || userState == CLOSED)) {
421					// This is a closing local connection - wait until it's
422					// gone away for real
423					locker.Unlock();
424					snooze(10000);
425					locker.Lock();
426						// TODO: make this better
427					if (!retrying) {
428						retrying = true;
429						retry = 5;
430					}
431					break;
432				}
433
434				if ((endpoint->socket->options & SO_REUSEADDR) == 0)
435					return EADDRINUSE;
436
437				if (userState != TIME_WAIT && userState != CLOSED)
438					return EADDRINUSE;
439			}
440		}
441	} while (retry-- > 0);
442
443	return _Bind(endpoint, *address);
444}
445
446
447/*! You must have fLock write locked when calling this method. */
448status_t
449EndpointManager::_BindToEphemeral(TCPEndpoint* endpoint,
450	const sockaddr* address)
451{
452	TRACE(("EndpointManager::BindToEphemeral(%p)\n", endpoint));
453
454	uint32 max = fLastPort + 65536;
455
456	for (int32 i = 1; i < 5; i++) {
457		// try to retrieve a more or less random port
458		uint32 step = i == 4 ? 1 : (system_time() & 0x1f) + 1;
459		uint32 counter = fLastPort + step;
460
461		while (counter < max) {
462			uint16 port = counter & 0xffff;
463			if (port <= kLastReservedPort)
464				port += kLastReservedPort;
465
466			fLastPort = port;
467			port = htons(port);
468
469			if (!fEndpointHash.Lookup(port).HasNext()) {
470				// found a port
471				SocketAddressStorage newAddress(AddressModule());
472				newAddress.SetTo(address);
473				newAddress.SetPort(port);
474
475				TRACE(("   EndpointManager::BindToEphemeral(%p) -> %s\n",
476					endpoint, AddressString(Domain(), *newAddress,
477					true).Data()));
478				T(Bind(endpoint, newAddress, true));
479
480				return _Bind(endpoint, *newAddress);
481			}
482
483			counter += step;
484		}
485	}
486
487	// could not find a port!
488	return EADDRINUSE;
489}
490
491
492status_t
493EndpointManager::_Bind(TCPEndpoint* endpoint, const sockaddr* address)
494{
495	// Thus far we have checked if the Bind() is allowed
496
497	status_t status = endpoint->next->module->bind(endpoint->next, address);
498	if (status < B_OK)
499		return status;
500
501	fEndpointHash.Insert(endpoint);
502
503	return B_OK;
504}
505
506
507status_t
508EndpointManager::Unbind(TCPEndpoint* endpoint)
509{
510	TRACE(("EndpointManager::Unbind(%p)\n", endpoint));
511	T(Unbind(endpoint));
512
513	if (endpoint == NULL || !endpoint->IsBound()) {
514		TRACE(("  endpoint is unbound.\n"));
515		return B_BAD_VALUE;
516	}
517
518	WriteLocker _(fLock);
519
520	if (!fEndpointHash.Remove(endpoint))
521		panic("bound endpoint %p not in hash!", endpoint);
522
523	fConnectionHash.Remove(endpoint);
524
525	(*endpoint->LocalAddress())->sa_len = 0;
526
527	return B_OK;
528}
529
530
531status_t
532EndpointManager::ReplyWithReset(tcp_segment_header& segment, net_buffer* buffer)
533{
534	TRACE(("TCP: Sending RST...\n"));
535
536	net_buffer* reply = gBufferModule->create(512);
537	if (reply == NULL)
538		return B_NO_MEMORY;
539
540	AddressModule()->set_to(reply->source, buffer->destination);
541	AddressModule()->set_to(reply->destination, buffer->source);
542
543	tcp_segment_header outSegment(TCP_FLAG_RESET);
544	outSegment.sequence = 0;
545	outSegment.acknowledge = 0;
546	outSegment.advertised_window = 0;
547	outSegment.urgent_offset = 0;
548
549	if ((segment.flags & TCP_FLAG_ACKNOWLEDGE) == 0) {
550		outSegment.flags |= TCP_FLAG_ACKNOWLEDGE;
551		outSegment.acknowledge = segment.sequence + buffer->size;
552		// TODO: Confirm:
553		if ((segment.flags & (TCP_FLAG_SYNCHRONIZE | TCP_FLAG_FINISH)) != 0)
554			outSegment.acknowledge++;
555	} else
556		outSegment.sequence = segment.acknowledge;
557
558	status_t status = add_tcp_header(AddressModule(), outSegment, reply);
559	if (status == B_OK)
560		status = Domain()->module->send_data(NULL, reply);
561
562	if (status != B_OK)
563		gBufferModule->free(reply);
564
565	return status;
566}
567
568
569void
570EndpointManager::Dump() const
571{
572	kprintf("-------- TCP Domain %p ---------\n", this);
573	kprintf("%10s %21s %21s %8s %8s %12s\n", "address", "local", "peer",
574		"recv-q", "send-q", "state");
575
576	ConnectionTable::Iterator iterator = fConnectionHash.GetIterator();
577
578	while (iterator.HasNext()) {
579		TCPEndpoint *endpoint = iterator.Next();
580
581		char localBuf[64], peerBuf[64];
582		endpoint->LocalAddress().AsString(localBuf, sizeof(localBuf), true);
583		endpoint->PeerAddress().AsString(peerBuf, sizeof(peerBuf), true);
584
585		kprintf("%p %21s %21s %8lu %8lu %12s\n", endpoint, localBuf, peerBuf,
586			endpoint->fReceiveQueue.Available(), endpoint->fSendQueue.Used(),
587			name_for_state(endpoint->State()));
588	}
589}
590
591