1/* 2 * Copyright 2006-2009, Haiku, Inc. All Rights Reserved. 3 * Distributed under the terms of the MIT License. 4 * 5 * Authors: 6 * Axel Dörfler, axeld@pinc-software.de 7 * Hugo Santos, hugosantos@gmail.com 8 */ 9 10 11#include "EndpointManager.h" 12 13#include <new> 14#include <unistd.h> 15 16#include <KernelExport.h> 17 18#include <NetUtilities.h> 19#include <tracing.h> 20 21#include "TCPEndpoint.h" 22 23 24//#define TRACE_ENDPOINT_MANAGER 25#ifdef TRACE_ENDPOINT_MANAGER 26# define TRACE(x) dprintf x 27#else 28# define TRACE(x) 29#endif 30 31#if TCP_TRACING 32# define ENDPOINT_TRACING 33#endif 34#ifdef ENDPOINT_TRACING 35namespace EndpointTracing { 36 37class Bind : public AbstractTraceEntry { 38public: 39 Bind(TCPEndpoint* endpoint, ConstSocketAddress& address, bool ephemeral) 40 : 41 fEndpoint(endpoint), 42 fEphemeral(ephemeral) 43 { 44 address.AsString(fAddress, sizeof(fAddress), true); 45 Initialized(); 46 } 47 48 Bind(TCPEndpoint* endpoint, SocketAddress& address, bool ephemeral) 49 : 50 fEndpoint(endpoint), 51 fEphemeral(ephemeral) 52 { 53 address.AsString(fAddress, sizeof(fAddress), true); 54 Initialized(); 55 } 56 57 virtual void AddDump(TraceOutput& out) 58 { 59 out.Print("tcp:e:%p bind%s address %s", fEndpoint, 60 fEphemeral ? " ephemeral" : "", fAddress); 61 } 62 63protected: 64 TCPEndpoint* fEndpoint; 65 char fAddress[32]; 66 bool fEphemeral; 67}; 68 69class Connect : public AbstractTraceEntry { 70public: 71 Connect(TCPEndpoint* endpoint) 72 : 73 fEndpoint(endpoint) 74 { 75 endpoint->LocalAddress().AsString(fLocal, sizeof(fLocal), true); 76 endpoint->PeerAddress().AsString(fPeer, sizeof(fPeer), true); 77 Initialized(); 78 } 79 80 virtual void AddDump(TraceOutput& out) 81 { 82 out.Print("tcp:e:%p connect local %s, peer %s", fEndpoint, fLocal, 83 fPeer); 84 } 85 86protected: 87 TCPEndpoint* fEndpoint; 88 char fLocal[32]; 89 char fPeer[32]; 90}; 91 92class Unbind : public AbstractTraceEntry { 93public: 94 Unbind(TCPEndpoint* endpoint) 95 : 96 fEndpoint(endpoint) 97 { 98 //fStackTrace = capture_tracing_stack_trace(10, 0, false); 99 100 endpoint->LocalAddress().AsString(fLocal, sizeof(fLocal), true); 101 endpoint->PeerAddress().AsString(fPeer, sizeof(fPeer), true); 102 Initialized(); 103 } 104 105#if 0 106 virtual void DumpStackTrace(TraceOutput& out) 107 { 108 out.PrintStackTrace(fStackTrace); 109 } 110#endif 111 112 virtual void AddDump(TraceOutput& out) 113 { 114 out.Print("tcp:e:%p unbind, local %s, peer %s", fEndpoint, fLocal, 115 fPeer); 116 } 117 118protected: 119 TCPEndpoint* fEndpoint; 120 //tracing_stack_trace* fStackTrace; 121 char fLocal[32]; 122 char fPeer[32]; 123}; 124 125} // namespace EndpointTracing 126 127# define T(x) new(std::nothrow) EndpointTracing::x 128#else 129# define T(x) 130#endif // ENDPOINT_TRACING 131 132 133static const uint16 kLastReservedPort = 1023; 134static const uint16 kFirstEphemeralPort = 40000; 135 136 137ConnectionHashDefinition::ConnectionHashDefinition(EndpointManager* manager) 138 : 139 fManager(manager) 140{ 141} 142 143 144size_t 145ConnectionHashDefinition::HashKey(const KeyType& key) const 146{ 147 return ConstSocketAddress(fManager->AddressModule(), 148 key.first).HashPair(key.second); 149} 150 151 152size_t 153ConnectionHashDefinition::Hash(TCPEndpoint* endpoint) const 154{ 155 return endpoint->LocalAddress().HashPair(*endpoint->PeerAddress()); 156} 157 158 159bool 160ConnectionHashDefinition::Compare(const KeyType& key, 161 TCPEndpoint* endpoint) const 162{ 163 return endpoint->LocalAddress().EqualTo(key.first, true) 164 && endpoint->PeerAddress().EqualTo(key.second, true); 165} 166 167 168TCPEndpoint*& 169ConnectionHashDefinition::GetLink(TCPEndpoint* endpoint) const 170{ 171 return endpoint->fConnectionHashLink; 172} 173 174 175// #pragma mark - 176 177 178size_t 179EndpointHashDefinition::HashKey(uint16 port) const 180{ 181 return port; 182} 183 184 185size_t 186EndpointHashDefinition::Hash(TCPEndpoint* endpoint) const 187{ 188 return endpoint->LocalAddress().Port(); 189} 190 191 192bool 193EndpointHashDefinition::Compare(uint16 port, TCPEndpoint* endpoint) const 194{ 195 return endpoint->LocalAddress().Port() == port; 196} 197 198 199bool 200EndpointHashDefinition::CompareValues(TCPEndpoint* first, 201 TCPEndpoint* second) const 202{ 203 return first->LocalAddress().Port() == second->LocalAddress().Port(); 204} 205 206 207TCPEndpoint*& 208EndpointHashDefinition::GetLink(TCPEndpoint* endpoint) const 209{ 210 return endpoint->fEndpointHashLink; 211} 212 213 214// #pragma mark - 215 216 217EndpointManager::EndpointManager(net_domain* domain) 218 : 219 fDomain(domain), 220 fConnectionHash(this), 221 fLastPort(kFirstEphemeralPort) 222{ 223 rw_lock_init(&fLock, "TCP endpoint manager"); 224} 225 226 227EndpointManager::~EndpointManager() 228{ 229 rw_lock_destroy(&fLock); 230} 231 232 233status_t 234EndpointManager::Init() 235{ 236 status_t status = fConnectionHash.Init(); 237 if (status == B_OK) 238 status = fEndpointHash.Init(); 239 240 return status; 241} 242 243 244// #pragma mark - connections 245 246 247/*! Returns the endpoint matching the connection. 248 You must hold the manager's lock when calling this method (either read or 249 write). 250*/ 251TCPEndpoint* 252EndpointManager::_LookupConnection(const sockaddr* local, const sockaddr* peer) 253{ 254 return fConnectionHash.Lookup(std::make_pair(local, peer)); 255} 256 257 258status_t 259EndpointManager::SetConnection(TCPEndpoint* endpoint, const sockaddr* _local, 260 const sockaddr* peer, const sockaddr* interfaceLocal) 261{ 262 TRACE(("EndpointManager::SetConnection(%p)\n", endpoint)); 263 264 WriteLocker _(fLock); 265 266 SocketAddressStorage local(AddressModule()); 267 local.SetTo(_local); 268 269 if (local.IsEmpty(false)) { 270 uint16 port = local.Port(); 271 local.SetTo(interfaceLocal); 272 local.SetPort(port); 273 } 274 275 if (_LookupConnection(*local, peer) != NULL) 276 return EADDRINUSE; 277 278 endpoint->LocalAddress().SetTo(*local); 279 endpoint->PeerAddress().SetTo(peer); 280 T(Connect(endpoint)); 281 282 fConnectionHash.Insert(endpoint); 283 return B_OK; 284} 285 286 287status_t 288EndpointManager::SetPassive(TCPEndpoint* endpoint) 289{ 290 WriteLocker _(fLock); 291 292 if (!endpoint->IsBound()) { 293 // if the socket is unbound first bind it to ephemeral 294 SocketAddressStorage local(AddressModule()); 295 local.SetToEmpty(); 296 297 status_t status = _BindToEphemeral(endpoint, *local); 298 if (status < B_OK) 299 return status; 300 } 301 302 SocketAddressStorage passive(AddressModule()); 303 passive.SetToEmpty(); 304 305 if (_LookupConnection(*endpoint->LocalAddress(), *passive)) 306 return EADDRINUSE; 307 308 endpoint->PeerAddress().SetTo(*passive); 309 fConnectionHash.Insert(endpoint); 310 return B_OK; 311} 312 313 314TCPEndpoint* 315EndpointManager::FindConnection(sockaddr* local, sockaddr* peer) 316{ 317 ReadLocker _(fLock); 318 319 TCPEndpoint *endpoint = _LookupConnection(local, peer); 320 if (endpoint != NULL) { 321 TRACE(("TCP: Received packet corresponds to explicit endpoint %p\n", 322 endpoint)); 323 if (gSocketModule->acquire_socket(endpoint->socket)) 324 return endpoint; 325 } 326 327 // no explicit endpoint exists, check for wildcard endpoints 328 329 SocketAddressStorage wildcard(AddressModule()); 330 wildcard.SetToEmpty(); 331 332 endpoint = _LookupConnection(local, *wildcard); 333 if (endpoint != NULL) { 334 TRACE(("TCP: Received packet corresponds to wildcard endpoint %p\n", 335 endpoint)); 336 if (gSocketModule->acquire_socket(endpoint->socket)) 337 return endpoint; 338 } 339 340 SocketAddressStorage localWildcard(AddressModule()); 341 localWildcard.SetToEmpty(); 342 localWildcard.SetPort(AddressModule()->get_port(local)); 343 344 endpoint = _LookupConnection(*localWildcard, *wildcard); 345 if (endpoint != NULL) { 346 TRACE(("TCP: Received packet corresponds to local wildcard endpoint " 347 "%p\n", endpoint)); 348 if (gSocketModule->acquire_socket(endpoint->socket)) 349 return endpoint; 350 } 351 352 // no matching endpoint exists 353 TRACE(("TCP: no matching endpoint!\n")); 354 355 return NULL; 356} 357 358 359// #pragma mark - endpoints 360 361 362status_t 363EndpointManager::Bind(TCPEndpoint* endpoint, const sockaddr* address) 364{ 365 // check the family 366 if (!AddressModule()->is_same_family(address)) 367 return EAFNOSUPPORT; 368 369 WriteLocker locker(fLock); 370 371 if (AddressModule()->get_port(address) == 0) 372 return _BindToEphemeral(endpoint, address); 373 374 return _BindToAddress(locker, endpoint, address); 375} 376 377 378status_t 379EndpointManager::BindChild(TCPEndpoint* endpoint) 380{ 381 WriteLocker _(fLock); 382 return _Bind(endpoint, *endpoint->LocalAddress()); 383} 384 385 386/*! You must have fLock write locked when calling this method. */ 387status_t 388EndpointManager::_BindToAddress(WriteLocker& locker, TCPEndpoint* endpoint, 389 const sockaddr* _address) 390{ 391 ConstSocketAddress address(AddressModule(), _address); 392 uint16 port = address.Port(); 393 394 TRACE(("EndpointManager::BindToAddress(%p)\n", endpoint)); 395 T(Bind(endpoint, address, false)); 396 397 // TODO: this check follows very typical UNIX semantics 398 // and generally should be improved. 399 if (ntohs(port) <= kLastReservedPort && geteuid() != 0) 400 return B_PERMISSION_DENIED; 401 402 bool retrying = false; 403 int32 retry = 0; 404 do { 405 EndpointTable::ValueIterator portUsers = fEndpointHash.Lookup(port); 406 retry = false; 407 408 while (portUsers.HasNext()) { 409 TCPEndpoint* user = portUsers.Next(); 410 411 if (user->LocalAddress().IsEmpty(false) 412 || address.EqualTo(*user->LocalAddress(), false)) { 413 // Check if this belongs to a local connection 414 415 // Note, while we hold our lock, the endpoint cannot go away, 416 // it can only change its state - IsLocal() is safe to be used 417 // without having the endpoint locked. 418 tcp_state userState = user->State(); 419 if (user->IsLocal() 420 && (userState > ESTABLISHED || userState == CLOSED)) { 421 // This is a closing local connection - wait until it's 422 // gone away for real 423 locker.Unlock(); 424 snooze(10000); 425 locker.Lock(); 426 // TODO: make this better 427 if (!retrying) { 428 retrying = true; 429 retry = 5; 430 } 431 break; 432 } 433 434 if ((endpoint->socket->options & SO_REUSEADDR) == 0) 435 return EADDRINUSE; 436 437 if (userState != TIME_WAIT && userState != CLOSED) 438 return EADDRINUSE; 439 } 440 } 441 } while (retry-- > 0); 442 443 return _Bind(endpoint, *address); 444} 445 446 447/*! You must have fLock write locked when calling this method. */ 448status_t 449EndpointManager::_BindToEphemeral(TCPEndpoint* endpoint, 450 const sockaddr* address) 451{ 452 TRACE(("EndpointManager::BindToEphemeral(%p)\n", endpoint)); 453 454 uint32 max = fLastPort + 65536; 455 456 for (int32 i = 1; i < 5; i++) { 457 // try to retrieve a more or less random port 458 uint32 step = i == 4 ? 1 : (system_time() & 0x1f) + 1; 459 uint32 counter = fLastPort + step; 460 461 while (counter < max) { 462 uint16 port = counter & 0xffff; 463 if (port <= kLastReservedPort) 464 port += kLastReservedPort; 465 466 fLastPort = port; 467 port = htons(port); 468 469 if (!fEndpointHash.Lookup(port).HasNext()) { 470 // found a port 471 SocketAddressStorage newAddress(AddressModule()); 472 newAddress.SetTo(address); 473 newAddress.SetPort(port); 474 475 TRACE((" EndpointManager::BindToEphemeral(%p) -> %s\n", 476 endpoint, AddressString(Domain(), *newAddress, 477 true).Data())); 478 T(Bind(endpoint, newAddress, true)); 479 480 return _Bind(endpoint, *newAddress); 481 } 482 483 counter += step; 484 } 485 } 486 487 // could not find a port! 488 return EADDRINUSE; 489} 490 491 492status_t 493EndpointManager::_Bind(TCPEndpoint* endpoint, const sockaddr* address) 494{ 495 // Thus far we have checked if the Bind() is allowed 496 497 status_t status = endpoint->next->module->bind(endpoint->next, address); 498 if (status < B_OK) 499 return status; 500 501 fEndpointHash.Insert(endpoint); 502 503 return B_OK; 504} 505 506 507status_t 508EndpointManager::Unbind(TCPEndpoint* endpoint) 509{ 510 TRACE(("EndpointManager::Unbind(%p)\n", endpoint)); 511 T(Unbind(endpoint)); 512 513 if (endpoint == NULL || !endpoint->IsBound()) { 514 TRACE((" endpoint is unbound.\n")); 515 return B_BAD_VALUE; 516 } 517 518 WriteLocker _(fLock); 519 520 if (!fEndpointHash.Remove(endpoint)) 521 panic("bound endpoint %p not in hash!", endpoint); 522 523 fConnectionHash.Remove(endpoint); 524 525 (*endpoint->LocalAddress())->sa_len = 0; 526 527 return B_OK; 528} 529 530 531status_t 532EndpointManager::ReplyWithReset(tcp_segment_header& segment, net_buffer* buffer) 533{ 534 TRACE(("TCP: Sending RST...\n")); 535 536 net_buffer* reply = gBufferModule->create(512); 537 if (reply == NULL) 538 return B_NO_MEMORY; 539 540 AddressModule()->set_to(reply->source, buffer->destination); 541 AddressModule()->set_to(reply->destination, buffer->source); 542 543 tcp_segment_header outSegment(TCP_FLAG_RESET); 544 outSegment.sequence = 0; 545 outSegment.acknowledge = 0; 546 outSegment.advertised_window = 0; 547 outSegment.urgent_offset = 0; 548 549 if ((segment.flags & TCP_FLAG_ACKNOWLEDGE) == 0) { 550 outSegment.flags |= TCP_FLAG_ACKNOWLEDGE; 551 outSegment.acknowledge = segment.sequence + buffer->size; 552 // TODO: Confirm: 553 if ((segment.flags & (TCP_FLAG_SYNCHRONIZE | TCP_FLAG_FINISH)) != 0) 554 outSegment.acknowledge++; 555 } else 556 outSegment.sequence = segment.acknowledge; 557 558 status_t status = add_tcp_header(AddressModule(), outSegment, reply); 559 if (status == B_OK) 560 status = Domain()->module->send_data(NULL, reply); 561 562 if (status != B_OK) 563 gBufferModule->free(reply); 564 565 return status; 566} 567 568 569void 570EndpointManager::Dump() const 571{ 572 kprintf("-------- TCP Domain %p ---------\n", this); 573 kprintf("%10s %21s %21s %8s %8s %12s\n", "address", "local", "peer", 574 "recv-q", "send-q", "state"); 575 576 ConnectionTable::Iterator iterator = fConnectionHash.GetIterator(); 577 578 while (iterator.HasNext()) { 579 TCPEndpoint *endpoint = iterator.Next(); 580 581 char localBuf[64], peerBuf[64]; 582 endpoint->LocalAddress().AsString(localBuf, sizeof(localBuf), true); 583 endpoint->PeerAddress().AsString(peerBuf, sizeof(peerBuf), true); 584 585 kprintf("%p %21s %21s %8lu %8lu %12s\n", endpoint, localBuf, peerBuf, 586 endpoint->fReceiveQueue.Available(), endpoint->fSendQueue.Used(), 587 name_for_state(endpoint->State())); 588 } 589} 590 591