1/* 2 * Copyright 2005, Ingo Weinhold <bonefish@cs.tu-berlin.de>. 3 * All rights reserved. Distributed under the terms of the MIT License. 4 */ 5 6 7#include <boot/net/UDP.h> 8 9#include <stdio.h> 10 11#include <KernelExport.h> 12 13#include <boot/net/ChainBuffer.h> 14#include <boot/net/NetStack.h> 15 16 17//#define TRACE_UDP 18#ifdef TRACE_UDP 19# define TRACE(x) dprintf x 20#else 21# define TRACE(x) ; 22#endif 23 24 25// #pragma mark - UDPPacket 26 27 28UDPPacket::UDPPacket() 29 : 30 fNext(NULL), 31 fData(NULL), 32 fSize(0) 33{ 34} 35 36 37UDPPacket::~UDPPacket() 38{ 39 free(fData); 40} 41 42 43status_t 44UDPPacket::SetTo(const void *data, size_t size, ip_addr_t sourceAddress, 45 uint16 sourcePort, ip_addr_t destinationAddress, uint16 destinationPort) 46{ 47 if (data == NULL) 48 return B_BAD_VALUE; 49 50 // clone the data 51 fData = malloc(size); 52 if (fData == NULL) 53 return B_NO_MEMORY; 54 memcpy(fData, data, size); 55 56 fSize = size; 57 fSourceAddress = sourceAddress; 58 fDestinationAddress = destinationAddress; 59 fSourcePort = sourcePort; 60 fDestinationPort = destinationPort; 61 62 return B_OK; 63} 64 65 66UDPPacket * 67UDPPacket::Next() const 68{ 69 return fNext; 70} 71 72 73void 74UDPPacket::SetNext(UDPPacket *next) 75{ 76 fNext = next; 77} 78 79 80const void * 81UDPPacket::Data() const 82{ 83 return fData; 84} 85 86 87size_t 88UDPPacket::DataSize() const 89{ 90 return fSize; 91} 92 93 94ip_addr_t 95UDPPacket::SourceAddress() const 96{ 97 return fSourceAddress; 98} 99 100 101uint16 102UDPPacket::SourcePort() const 103{ 104 return fSourcePort; 105} 106 107 108ip_addr_t 109UDPPacket::DestinationAddress() const 110{ 111 return fDestinationAddress; 112} 113 114 115uint16 116UDPPacket::DestinationPort() const 117{ 118 return fDestinationPort; 119} 120 121 122// #pragma mark - UDPSocket 123 124 125UDPSocket::UDPSocket() 126 : 127 fUDPService(NetStack::Default()->GetUDPService()), 128 fFirstPacket(NULL), 129 fLastPacket(NULL), 130 fAddress(INADDR_ANY), 131 fPort(0) 132{ 133} 134 135 136UDPSocket::~UDPSocket() 137{ 138 if (fPort != 0 && fUDPService != NULL) 139 fUDPService->UnbindSocket(this); 140} 141 142 143status_t 144UDPSocket::Bind(ip_addr_t address, uint16 port) 145{ 146 if (fUDPService == NULL) { 147 printf("UDPSocket::Bind(): no UDP service\n"); 148 return B_NO_INIT; 149 } 150 151 if (address == INADDR_BROADCAST || port == 0) { 152 printf("UDPSocket::Bind(): broadcast IP or port 0\n"); 153 return B_BAD_VALUE; 154 } 155 156 if (fPort != 0) { 157 printf("UDPSocket::Bind(): already bound\n"); 158 return EALREADY; 159 // correct code? 160 } 161 162 status_t error = fUDPService->BindSocket(this, address, port); 163 if (error != B_OK) { 164 printf("UDPSocket::Bind(): service BindSocket() failed\n"); 165 return error; 166 } 167 168 fAddress = address; 169 fPort = port; 170 171 return B_OK; 172} 173 174 175status_t 176UDPSocket::Send(ip_addr_t destinationAddress, uint16 destinationPort, 177 ChainBuffer *buffer) 178{ 179 if (fUDPService == NULL) 180 return B_NO_INIT; 181 182 return fUDPService->Send(fPort, destinationAddress, destinationPort, 183 buffer); 184} 185 186 187status_t 188UDPSocket::Send(ip_addr_t destinationAddress, uint16 destinationPort, 189 const void *data, size_t size) 190{ 191 if (data == NULL) 192 return B_BAD_VALUE; 193 194 ChainBuffer buffer((void*)data, size); 195 return Send(destinationAddress, destinationPort, &buffer); 196} 197 198 199status_t 200UDPSocket::Receive(UDPPacket **_packet, bigtime_t timeout) 201{ 202 if (fUDPService == NULL) 203 return B_NO_INIT; 204 205 if (_packet == NULL) 206 return B_BAD_VALUE; 207 208 bigtime_t startTime = system_time(); 209 for (;;) { 210 fUDPService->ProcessIncomingPackets(); 211 *_packet = PopPacket(); 212 if (*_packet != NULL) 213 return B_OK; 214 215 if (system_time() - startTime > timeout) 216 return (timeout == 0 ? B_WOULD_BLOCK : B_TIMED_OUT); 217 } 218} 219 220 221void 222UDPSocket::PushPacket(UDPPacket *packet) 223{ 224 if (fLastPacket != NULL) 225 fLastPacket->SetNext(packet); 226 else 227 fFirstPacket = packet; 228 229 fLastPacket = packet; 230 packet->SetNext(NULL); 231} 232 233 234UDPPacket * 235UDPSocket::PopPacket() 236{ 237 if (fFirstPacket == NULL) 238 return NULL; 239 240 UDPPacket *packet = fFirstPacket; 241 fFirstPacket = packet->Next(); 242 243 if (fFirstPacket == NULL) 244 fLastPacket = NULL; 245 246 packet->SetNext(NULL); 247 return packet; 248} 249 250 251// #pragma mark - UDPService 252 253 254UDPService::UDPService(IPService *ipService) 255 : 256 IPSubService(kUDPServiceName), 257 fIPService(ipService) 258{ 259} 260 261 262UDPService::~UDPService() 263{ 264 if (fIPService != NULL) 265 fIPService->UnregisterIPSubService(this); 266} 267 268 269status_t 270UDPService::Init() 271{ 272 if (fIPService == NULL) 273 return B_BAD_VALUE; 274 if (!fIPService->RegisterIPSubService(this)) 275 return B_NO_MEMORY; 276 return B_OK; 277} 278 279 280uint8 281UDPService::IPProtocol() const 282{ 283 return IPPROTO_UDP; 284} 285 286 287void 288UDPService::HandleIPPacket(IPService *ipService, ip_addr_t sourceIP, 289 ip_addr_t destinationIP, const void *data, size_t size) 290{ 291 TRACE(("UDPService::HandleIPPacket(): source: %08lx, destination: %08lx, " 292 "%lu - %lu bytes\n", sourceIP, destinationIP, size, 293 sizeof(udp_header))); 294 295 if (data == NULL || size < sizeof(udp_header)) 296 return; 297 298 const udp_header *header = (const udp_header*)data; 299 uint16 source = ntohs(header->source); 300 uint16 destination = ntohs(header->destination); 301 uint16 length = ntohs(header->length); 302 303 // check the header 304 if (length < sizeof(udp_header) || length > size 305 || (header->checksum != 0 // 0 => checksum disabled 306 && _ChecksumData(data, length, sourceIP, destinationIP) != 0)) { 307 TRACE(("UDPService::HandleIPPacket(): dropping packet -- invalid size " 308 "or checksum\n")); 309 return; 310 } 311 312 // find the target socket 313 UDPSocket *socket = _FindSocket(destinationIP, destination); 314 if (socket == NULL) 315 return; 316 317 // create a UDPPacket and queue it in the socket 318 UDPPacket *packet = new(nothrow) UDPPacket; 319 if (packet == NULL) 320 return; 321 status_t error = packet->SetTo((uint8*)data + sizeof(udp_header), 322 length - sizeof(udp_header), sourceIP, source, destinationIP, 323 destination); 324 if (error == B_OK) 325 socket->PushPacket(packet); 326 else 327 delete packet; 328} 329 330 331status_t 332UDPService::Send(uint16 sourcePort, ip_addr_t destinationAddress, 333 uint16 destinationPort, ChainBuffer *buffer) 334{ 335 TRACE(("UDPService::Send(source port: %hu, to: %08lx:%hu, %lu bytes)\n", 336 sourcePort, destinationAddress, destinationPort, 337 (buffer != NULL ? buffer->TotalSize() : 0))); 338 339 if (fIPService == NULL) 340 return B_NO_INIT; 341 342 if (buffer == NULL) 343 return B_BAD_VALUE; 344 345 // prepend the UDP header 346 udp_header header; 347 ChainBuffer headerBuffer(&header, sizeof(header), buffer); 348 header.source = htons(sourcePort); 349 header.destination = htons(destinationPort); 350 header.length = htons(headerBuffer.TotalSize()); 351 352 // compute the checksum 353 header.checksum = 0; 354 header.checksum = htons(_ChecksumBuffer(&headerBuffer, 355 fIPService->IPAddress(), destinationAddress, 356 headerBuffer.TotalSize())); 357 // 0 means checksum disabled; 0xffff is equivalent in this case 358 if (header.checksum == 0) 359 header.checksum = 0xffff; 360 361 return fIPService->Send(destinationAddress, IPPROTO_UDP, &headerBuffer); 362} 363 364 365void 366UDPService::ProcessIncomingPackets() 367{ 368 if (fIPService != NULL) 369 fIPService->ProcessIncomingPackets(); 370} 371 372 373status_t 374UDPService::BindSocket(UDPSocket *socket, ip_addr_t address, uint16 port) 375{ 376 if (socket == NULL) 377 return B_BAD_VALUE; 378 379 if (_FindSocket(address, port) != NULL) { 380 printf("UDPService::BindSocket(): address in use\n"); 381 return EADDRINUSE; 382 } 383 384 return fSockets.Add(socket); 385} 386 387 388void 389UDPService::UnbindSocket(UDPSocket *socket) 390{ 391 fSockets.Remove(socket); 392} 393 394 395uint16 396UDPService::_ChecksumBuffer(ChainBuffer *buffer, ip_addr_t source, 397 ip_addr_t destination, uint16 length) 398{ 399 // The checksum is calculated over a pseudo-header plus the UDP packet. 400 // So we temporarily prepend the pseudo-header. 401 struct pseudo_header { 402 ip_addr_t source; 403 ip_addr_t destination; 404 uint8 pad; 405 uint8 protocol; 406 uint16 length; 407 } __attribute__ ((__packed__)); 408 pseudo_header header = { 409 htonl(source), 410 htonl(destination), 411 0, 412 IPPROTO_UDP, 413 htons(length) 414 }; 415 416 ChainBuffer headerBuffer(&header, sizeof(header), buffer); 417 uint16 checksum = ip_checksum(&headerBuffer); 418 headerBuffer.DetachNext(); 419 return checksum; 420} 421 422 423uint16 424UDPService::_ChecksumData(const void *data, uint16 length, ip_addr_t source, 425 ip_addr_t destination) 426{ 427 ChainBuffer buffer((void*)data, length); 428 return _ChecksumBuffer(&buffer, source, destination, length); 429} 430 431 432UDPSocket * 433UDPService::_FindSocket(ip_addr_t address, uint16 port) 434{ 435 int count = fSockets.Count(); 436 for (int i = 0; i < count; i++) { 437 UDPSocket *socket = fSockets.ElementAt(i); 438 if ((address == INADDR_ANY || socket->Address() == INADDR_ANY 439 || socket->Address() == address) 440 && port == socket->Port()) { 441 return socket; 442 } 443 } 444 445 return NULL; 446} 447