1/*
2 * Copyright 2006-2010, Haiku, Inc. All Rights Reserved.
3 * Distributed under the terms of the MIT License.
4 *
5 * Authors:
6 *		Oliver Tappe, zooey@hirschkaefer.de
7 *		Hugo Santos, hugosantos@gmail.com
8 */
9
10
11#include <net_buffer.h>
12#include <net_datalink.h>
13#include <net_protocol.h>
14#include <net_stack.h>
15
16#include <lock.h>
17#include <util/AutoLock.h>
18#include <util/DoublyLinkedList.h>
19#include <util/OpenHashTable.h>
20
21#include <KernelExport.h>
22
23#include <NetBufferUtilities.h>
24#include <NetUtilities.h>
25#include <ProtocolUtilities.h>
26
27#include <algorithm>
28#include <netinet/in.h>
29#include <netinet/ip.h>
30#include <new>
31#include <stdlib.h>
32#include <string.h>
33#include <utility>
34
35
// NOTE: The locking protocol requires that a UdpDomainSupport's lock be
//      acquired before the lock of any of its child UdpEndpoints. This order
//      is dictated by the receive path, which must access the endpoint hash
//      table blindly while holding the UdpDomainSupport's lock.
40
41
// Debug tracing. Uncomment the define below to turn the TRACE_* macros into
// dprintf() calls prefixed with system_time():
//   TRACE_EP     - for UdpEndpoint members (also prints "this")
//   TRACE_EPM    - usable anywhere (used by UdpEndpointManager)
//   TRACE_DOMAIN - for members of classes with a Domain() accessor (prints
//                  the domain's address family)
// In non-trace builds they expand to empty statements, so their arguments
// are never evaluated.
//#define TRACE_UDP
#ifdef TRACE_UDP
#	define TRACE_BLOCK(x) dump_block x
// do not remove the space before ', ##args' if you want this
// to compile with gcc 2.95
#	define TRACE_EP(format, args...)	dprintf("UDP [%llu] %p " format "\n", \
		system_time(), this , ##args)
#	define TRACE_EPM(format, args...)	dprintf("UDP [%llu] " format "\n", \
		system_time() , ##args)
#	define TRACE_DOMAIN(format, args...)	dprintf("UDP [%llu] (%d) " format \
		"\n", system_time(), Domain()->family , ##args)
#else
#	define TRACE_BLOCK(x)
#	define TRACE_EP(args...)	do { } while (0)
#	define TRACE_EPM(args...)	do { } while (0)
#	define TRACE_DOMAIN(args...)	do { } while (0)
#endif
59
60
/*!	The UDP header as it appears on the wire (RFC 768). All fields are kept
	in network byte order (see the ntohs()/htons() conversions in Deframe()
	and SendRoutedData()).
*/
struct udp_header {
	uint16 source_port;
	uint16 destination_port;
	uint16 udp_length;
		// length of header plus payload, in bytes
	uint16 udp_checksum;
		// 0 means "no checksum" on receive (see Deframe())
} _PACKED;


// Accessor for the checksum field of the UDP header at the start of a
// net_buffer; used to patch in the checksum once it has been computed.
typedef NetBufferField<uint16, offsetof(udp_header, udp_checksum)>
	UDPChecksumField;

class UdpDomainSupport;
73
/*!	The protocol object of a single UDP socket. It is part of the socket's
	protocol chain (net_protocol; requests it does not handle itself are
	passed on to "next") and queues received datagrams in its DatagramSocket<>
	base. Binding and connection state is kept by the per-domain
	UdpDomainSupport it is attached to in Open().
*/
class UdpEndpoint : public net_protocol, public DatagramSocket<> {
public:
								UdpEndpoint(net_socket* socket);

			status_t			Bind(const sockaddr* newAddr);
			status_t			Unbind(sockaddr* newAddr);
			status_t			Connect(const sockaddr* newAddr);

			status_t			Open();
			status_t			Close();
			status_t			Free();

			status_t			SendRoutedData(net_buffer* buffer,
									net_route* route);
			status_t			SendData(net_buffer* buffer);

			ssize_t				BytesAvailable();
			status_t			FetchData(size_t numBytes, uint32 flags,
									net_buffer** _buffer);

			// StoreData() queues a clone of a demultiplexed buffer (the caller
			// keeps the original); DeliverData() clones, deframes and queues a
			// buffer that is handed to the endpoint directly.
			status_t			StoreData(net_buffer* buffer);
			status_t			DeliverData(net_buffer* buffer);

			// only the domain support will change/check the Active flag so
			// we don't really need to protect it with the socket lock.
			bool				IsActive() const { return fActive; }
			void				SetActive(bool newValue) { fActive = newValue; }

			// bucket chain link used by UdpDomainSupport's endpoint hash table
			UdpEndpoint*&		HashTableLink() { return fLink; }

			// prints a one-line summary via kprintf()
			void				Dump() const;

private:
			UdpDomainSupport*	fManager;
									// the domain support this endpoint was
									// registered with in Open()
			bool				fActive;
									// an active UdpEndpoint is part of the
									// endpoint hash (and it is bound and
									// optionally connected)

			UdpEndpoint*		fLink;
};
115
116
117class UdpDomainSupport;
118
119struct UdpHashDefinition {
120	typedef net_address_module_info ParentType;
121	typedef std::pair<const sockaddr *, const sockaddr *> KeyType;
122	typedef UdpEndpoint ValueType;
123
124	UdpHashDefinition(net_address_module_info *_module)
125		: module(_module) {}
126	UdpHashDefinition(const UdpHashDefinition& definition)
127		: module(definition.module) {}
128
129	size_t HashKey(const KeyType &key) const
130	{
131		return _Mix(module->hash_address_pair(key.first, key.second));
132	}
133
134	size_t Hash(UdpEndpoint *endpoint) const
135	{
136		return _Mix(endpoint->LocalAddress().HashPair(
137			*endpoint->PeerAddress()));
138	}
139
140	static size_t _Mix(size_t hash)
141	{
142		// move the bits into the relevant range (as defined by kNumHashBuckets)
143		return (hash & 0x000007FF) ^ (hash & 0x003FF800) >> 11
144			^ (hash & 0xFFC00000UL) >> 22;
145	}
146
147	bool Compare(const KeyType &key, UdpEndpoint *endpoint) const
148	{
149		return endpoint->LocalAddress().EqualTo(key.first, true)
150			&& endpoint->PeerAddress().EqualTo(key.second, true);
151	}
152
153	UdpEndpoint *&GetLink(UdpEndpoint *endpoint) const
154	{
155		return endpoint->HashTableLink();
156	}
157
158	net_address_module_info *module;
159};
160
161
/*!	Per-domain (address family) UDP state: the hash table of active (bound)
	endpoints, keyed by (local address, peer address), and the ephemeral port
	allocator. One instance is shared by all endpoints of a domain and is
	reference counted via Ref()/Put().
*/
class UdpDomainSupport : public DoublyLinkedListLinkImpl<UdpDomainSupport> {
public:
	UdpDomainSupport(net_domain *domain);
	~UdpDomainSupport();

	status_t Init();

	net_domain *Domain() const { return fDomain; }

	// Endpoint reference count. Not atomic: UdpEndpointManager's
	// OpenEndpoint()/FreeEndpoint() call these with the manager's lock held.
	// Put() returns true once the last reference is gone.
	void Ref() { fEndpointCount++; }
	bool Put() { fEndpointCount--; return fEndpointCount == 0; }

	// The following entry points acquire fLock themselves.
	status_t DemuxIncomingBuffer(net_buffer* buffer);
	status_t DeliverError(status_t error, net_buffer* buffer);

	status_t BindEndpoint(UdpEndpoint *endpoint, const sockaddr *address);
	status_t ConnectEndpoint(UdpEndpoint *endpoint, const sockaddr *address);
	status_t UnbindEndpoint(UdpEndpoint *endpoint);

	void DumpEndpoints() const;

private:
	// The helpers below are called with fLock held.
	status_t _BindEndpoint(UdpEndpoint *endpoint, const sockaddr *address);
	status_t _Bind(UdpEndpoint *endpoint, const sockaddr *address);
	status_t _BindToEphemeral(UdpEndpoint *endpoint, const sockaddr *address);
	status_t _FinishBind(UdpEndpoint *endpoint, const sockaddr *address);

	UdpEndpoint *_FindActiveEndpoint(const sockaddr *ourAddress,
		const sockaddr *peerAddress, uint32 index = 0);
	status_t _DemuxBroadcast(net_buffer *buffer);
	status_t _DemuxUnicast(net_buffer *buffer);

	uint16 _GetNextEphemeral();
	UdpEndpoint *_EndpointWithPort(uint16 port) const;

	net_address_module_info *AddressModule() const
		{ return fDomain->address_module; }

	typedef BOpenHashTable<UdpHashDefinition, false> EndpointTable;

	mutex			fLock;
	net_domain		*fDomain;
	uint16			fLastUsedEphemeral;
	EndpointTable	fActiveEndpoints;
	uint32			fEndpointCount;

	// ephemeral port range handed out by _GetNextEphemeral()
	static const uint16		kFirst = 49152;
	static const uint16		kLast = 65535;
	static const uint32		kNumHashBuckets = 0x800;
							// if you change this, adjust the bit folding in
							// UdpHashDefinition::_Mix() accordingly!
};
214
215
typedef DoublyLinkedList<UdpDomainSupport> UdpDomainList;


/*!	Module-wide dispatcher. It owns the list of UdpDomainSupport objects (one
	per domain that currently has endpoints) and routes incoming data and
	errors to the matching one. fLock protects fDomains and the domain
	supports' reference counts.
*/
class UdpEndpointManager {
public:
								UdpEndpointManager();
								~UdpEndpointManager();

			status_t			InitCheck() const;

			// inbound path
			status_t			ReceiveData(net_buffer* buffer);
			status_t			ReceiveError(status_t error,
									net_buffer* buffer);
			status_t			Deframe(net_buffer* buffer);

			// endpoint lifetime: OpenEndpoint() creates the domain support on
			// demand and references it, FreeEndpoint() releases it
			UdpDomainSupport*	OpenEndpoint(UdpEndpoint* endpoint);
			status_t			FreeEndpoint(UdpDomainSupport* domain);

	static	int					DumpEndpoints(int argc, char *argv[]);

private:
	inline	net_domain*			_GetDomain(net_buffer* buffer);
			UdpDomainSupport*	_GetDomainSupport(net_domain* domain,
									bool create);
			UdpDomainSupport*	_GetDomainSupport(net_buffer* buffer);

			mutex				fLock;
			status_t			fStatus;
			UdpDomainList		fDomains;
};
246
247
// The manager instance used by the module hooks and the receive path.
// NOTE(review): it is created/destroyed outside this view (presumably by the
// module's init/uninit code) — confirm before relying on it being non-NULL.
static UdpEndpointManager *sUdpEndpointManager;

// Interfaces of the other network stack modules used by UDP.
// NOTE(review): presumably filled in by the module dependency mechanism
// before any hook runs; not visible here.
net_buffer_module_info *gBufferModule;
net_datalink_module_info *gDatalinkModule;
net_stack_module_info *gStackModule;
net_socket_module_info *gSocketModule;
254
255
256// #pragma mark -
257
258
259UdpDomainSupport::UdpDomainSupport(net_domain *domain)
260	:
261	fDomain(domain),
262	fActiveEndpoints(domain->address_module),
263	fEndpointCount(0)
264{
265	mutex_init(&fLock, "udp domain");
266
267	fLastUsedEphemeral = kFirst + rand() % (kLast - kFirst);
268}
269
270
271UdpDomainSupport::~UdpDomainSupport()
272{
273	mutex_destroy(&fLock);
274}
275
276
277status_t
278UdpDomainSupport::Init()
279{
280	return fActiveEndpoints.Init(kNumHashBuckets);
281}
282
283
284status_t
285UdpDomainSupport::DemuxIncomingBuffer(net_buffer *buffer)
286{
287	// NOTE: multicast is delivered directly to the endpoint
288	MutexLocker _(fLock);
289
290	if ((buffer->flags & MSG_BCAST) != 0)
291		return _DemuxBroadcast(buffer);
292	if ((buffer->flags & MSG_MCAST) != 0)
293		return B_ERROR;
294
295	return _DemuxUnicast(buffer);
296}
297
298
299status_t
300UdpDomainSupport::DeliverError(status_t error, net_buffer* buffer)
301{
302	if ((buffer->flags & (MSG_BCAST | MSG_MCAST)) != 0)
303		return B_ERROR;
304
305	MutexLocker _(fLock);
306
307	// Forward the error to the socket
308	UdpEndpoint* endpoint = _FindActiveEndpoint(buffer->source,
309		buffer->destination);
310	if (endpoint != NULL) {
311		gSocketModule->notify(endpoint->Socket(), B_SELECT_ERROR, error);
312		endpoint->NotifyOne();
313	}
314
315	gBufferModule->free(buffer);
316	return B_OK;
317}
318
319
320status_t
321UdpDomainSupport::BindEndpoint(UdpEndpoint *endpoint,
322	const sockaddr *address)
323{
324	if (!AddressModule()->is_same_family(address))
325		return EAFNOSUPPORT;
326
327	MutexLocker _(fLock);
328
329	if (endpoint->IsActive())
330		return EINVAL;
331
332	return _BindEndpoint(endpoint, address);
333}
334
335
336status_t
337UdpDomainSupport::ConnectEndpoint(UdpEndpoint *endpoint,
338	const sockaddr *address)
339{
340	MutexLocker _(fLock);
341
342	if (endpoint->IsActive()) {
343		fActiveEndpoints.Remove(endpoint);
344		endpoint->SetActive(false);
345	}
346
347	if (address->sa_family == AF_UNSPEC) {
348		// [Stevens-UNP1, p226]: specifying AF_UNSPEC requests a "disconnect",
349		// so we reset the peer address:
350		endpoint->PeerAddress().SetToEmpty();
351	} else {
352		if (!AddressModule()->is_same_family(address))
353			return EAFNOSUPPORT;
354
355		// consider destination address INADDR_ANY as INADDR_LOOPBACK
356		sockaddr_storage _address;
357		if (AddressModule()->is_empty_address(address, false)) {
358			AddressModule()->get_loopback_address((sockaddr *)&_address);
359			// for IPv4 and IPv6 the port is at the same offset
360			((sockaddr_in&)_address).sin_port
361				= ((sockaddr_in *)address)->sin_port;
362			address = (sockaddr *)&_address;
363		}
364
365		status_t status = endpoint->PeerAddress().SetTo(address);
366		if (status < B_OK)
367			return status;
368		struct net_route *routeToDestination
369			= gDatalinkModule->get_route(fDomain, address);
370		if (routeToDestination) {
371			status = endpoint->LocalAddress().SetTo(
372				routeToDestination->interface_address->local);
373			gDatalinkModule->put_route(fDomain, routeToDestination);
374			if (status < B_OK)
375				return status;
376		}
377	}
378
379	// we need to activate no matter whether or not we have just disconnected,
380	// as calling connect() always triggers an implicit bind():
381	return _BindEndpoint(endpoint, *endpoint->LocalAddress());
382}
383
384
385status_t
386UdpDomainSupport::UnbindEndpoint(UdpEndpoint *endpoint)
387{
388	MutexLocker _(fLock);
389
390	if (endpoint->IsActive())
391		fActiveEndpoints.Remove(endpoint);
392
393	endpoint->SetActive(false);
394
395	return B_OK;
396}
397
398
399void
400UdpDomainSupport::DumpEndpoints() const
401{
402	kprintf("-------- UDP Domain %p ---------\n", this);
403	kprintf("%10s %20s %20s %8s\n", "address", "local", "peer", "recv-q");
404
405	EndpointTable::Iterator it = fActiveEndpoints.GetIterator();
406
407	while (UdpEndpoint* endpoint = it.Next()) {
408		endpoint->Dump();
409	}
410}
411
412
413status_t
414UdpDomainSupport::_BindEndpoint(UdpEndpoint *endpoint,
415	const sockaddr *address)
416{
417	if (AddressModule()->get_port(address) == 0)
418		return _BindToEphemeral(endpoint, address);
419
420	return _Bind(endpoint, address);
421}
422
423
424status_t
425UdpDomainSupport::_Bind(UdpEndpoint *endpoint, const sockaddr *address)
426{
427	int socketOptions = endpoint->Socket()->options;
428
429	EndpointTable::Iterator it = fActiveEndpoints.GetIterator();
430
431	// Iterate over all active UDP-endpoints and check if the requested bind
432	// is allowed (see figure 22.24 in [Stevens - TCP2, p735]):
433	TRACE_DOMAIN("CheckBindRequest() for %s...", AddressString(fDomain,
434		address, true).Data());
435
436	while (it.HasNext()) {
437		UdpEndpoint *otherEndpoint = it.Next();
438
439		TRACE_DOMAIN("  ...checking endpoint %p (port=%u)...", otherEndpoint,
440			ntohs(otherEndpoint->LocalAddress().Port()));
441
442		if (otherEndpoint->LocalAddress().EqualPorts(address)) {
443			// port is already bound, SO_REUSEADDR or SO_REUSEPORT is required:
444			if ((otherEndpoint->Socket()->options
445					& (SO_REUSEADDR | SO_REUSEPORT)) == 0
446				|| (socketOptions & (SO_REUSEADDR | SO_REUSEPORT)) == 0)
447				return EADDRINUSE;
448
449			// if both addresses are the same, SO_REUSEPORT is required:
450			if (otherEndpoint->LocalAddress().EqualTo(address, false)
451				&& ((otherEndpoint->Socket()->options & SO_REUSEPORT) == 0
452					|| (socketOptions & SO_REUSEPORT) == 0))
453				return EADDRINUSE;
454		}
455	}
456
457	return _FinishBind(endpoint, address);
458}
459
460
461status_t
462UdpDomainSupport::_BindToEphemeral(UdpEndpoint *endpoint,
463	const sockaddr *address)
464{
465	SocketAddressStorage newAddress(AddressModule());
466	status_t status = newAddress.SetTo(address);
467	if (status < B_OK)
468		return status;
469
470	uint16 allocedPort = _GetNextEphemeral();
471	if (allocedPort == 0)
472		return ENOBUFS;
473
474	newAddress.SetPort(htons(allocedPort));
475
476	return _FinishBind(endpoint, *newAddress);
477}
478
479
480status_t
481UdpDomainSupport::_FinishBind(UdpEndpoint *endpoint, const sockaddr *address)
482{
483	status_t status = endpoint->next->module->bind(endpoint->next, address);
484	if (status < B_OK)
485		return status;
486
487	fActiveEndpoints.Insert(endpoint);
488	endpoint->SetActive(true);
489
490	return B_OK;
491}
492
493
494UdpEndpoint *
495UdpDomainSupport::_FindActiveEndpoint(const sockaddr *ourAddress,
496	const sockaddr *peerAddress, uint32 index)
497{
498	ASSERT_LOCKED_MUTEX(&fLock);
499
500	TRACE_DOMAIN("finding Endpoint for %s <- %s",
501		AddressString(fDomain, ourAddress, true).Data(),
502		AddressString(fDomain, peerAddress, true).Data());
503
504	UdpEndpoint* endpoint = fActiveEndpoints.Lookup(
505		std::make_pair(ourAddress, peerAddress));
506
507	// Make sure the bound_to_device constraint is fulfilled
508	while (endpoint != NULL && endpoint->socket->bound_to_device != 0
509		&& index != 0 && endpoint->socket->bound_to_device != index) {
510		endpoint = endpoint->HashTableLink();
511		if (endpoint != NULL
512			&& (!endpoint->LocalAddress().EqualTo(ourAddress, true)
513				|| !endpoint->PeerAddress().EqualTo(peerAddress, true)))
514			return NULL;
515	}
516
517	return endpoint;
518}
519
520
521status_t
522UdpDomainSupport::_DemuxBroadcast(net_buffer* buffer)
523{
524	sockaddr* peerAddr = buffer->source;
525	sockaddr* broadcastAddr = buffer->destination;
526	uint16 incomingPort = AddressModule()->get_port(broadcastAddr);
527
528	sockaddr* mask = NULL;
529	if (buffer->interface_address != NULL)
530		mask = (sockaddr*)buffer->interface_address->mask;
531
532	TRACE_DOMAIN("_DemuxBroadcast(%p): mask %p\n", buffer, mask);
533
534	EndpointTable::Iterator iterator = fActiveEndpoints.GetIterator();
535
536	while (UdpEndpoint* endpoint = iterator.Next()) {
537		TRACE_DOMAIN("  _DemuxBroadcast(): checking endpoint %s...",
538			AddressString(fDomain, *endpoint->LocalAddress(), true).Data());
539
540		if (endpoint->socket->bound_to_device != 0
541			&& buffer->index != endpoint->socket->bound_to_device)
542			continue;
543
544		if (endpoint->LocalAddress().Port() != incomingPort) {
545			// ports don't match, so we do not dispatch to this endpoint...
546			continue;
547		}
548
549		if (!endpoint->PeerAddress().IsEmpty(true)) {
550			// endpoint is connected to a specific destination, we check if
551			// this datagram is from there:
552			if (!endpoint->PeerAddress().EqualTo(peerAddr, true)) {
553				// no, datagram is from another peer, so we do not dispatch to
554				// this endpoint...
555				continue;
556			}
557		}
558
559		if (endpoint->LocalAddress().MatchMasked(broadcastAddr, mask)
560			|| mask == NULL || endpoint->LocalAddress().IsEmpty(false)) {
561			// address matches, dispatch to this endpoint:
562			endpoint->StoreData(buffer);
563		}
564	}
565
566	return B_OK;
567}
568
569
570status_t
571UdpDomainSupport::_DemuxUnicast(net_buffer* buffer)
572{
573	TRACE_DOMAIN("_DemuxUnicast(%p)", buffer);
574
575	const sockaddr* localAddress = buffer->destination;
576	const sockaddr* peerAddress = buffer->source;
577
578	// look for full (most special) match:
579	UdpEndpoint* endpoint = _FindActiveEndpoint(localAddress, peerAddress,
580		buffer->index);
581	if (endpoint == NULL) {
582		// look for endpoint matching local address & port:
583		endpoint = _FindActiveEndpoint(localAddress, NULL, buffer->index);
584		if (endpoint == NULL) {
585			// look for endpoint matching peer address & port and local port:
586			SocketAddressStorage local(AddressModule());
587			local.SetToEmpty();
588			local.SetPort(AddressModule()->get_port(localAddress));
589			endpoint = _FindActiveEndpoint(*local, peerAddress, buffer->index);
590			if (endpoint == NULL) {
591				// last chance: look for endpoint matching local port only:
592				endpoint = _FindActiveEndpoint(*local, NULL, buffer->index);
593			}
594		}
595	}
596
597	if (endpoint == NULL) {
598		TRACE_DOMAIN("_DemuxUnicast(%p) - no matching endpoint found!", buffer);
599		return B_NAME_NOT_FOUND;
600	}
601
602	endpoint->StoreData(buffer);
603	return B_OK;
604}
605
606
607uint16
608UdpDomainSupport::_GetNextEphemeral()
609{
610	uint16 stop, curr;
611	if (fLastUsedEphemeral < kLast) {
612		stop = fLastUsedEphemeral;
613		curr = fLastUsedEphemeral + 1;
614	} else {
615		stop = kLast;
616		curr = kFirst;
617	}
618
619	TRACE_DOMAIN("_GetNextEphemeral(), last %hu, curr %hu, stop %hu",
620		fLastUsedEphemeral, curr, stop);
621
622	// TODO: a free list could be used to avoid the impact of these two
623	//        nested loops most of the time... let's see how bad this really is
624	for (; curr != stop; curr = (curr < kLast) ? (curr + 1) : kFirst) {
625		TRACE_DOMAIN("  _GetNextEphemeral(): trying port %hu...", curr);
626
627		if (_EndpointWithPort(htons(curr)) == NULL) {
628			TRACE_DOMAIN("  _GetNextEphemeral(): ...using port %hu", curr);
629			fLastUsedEphemeral = curr;
630			return curr;
631		}
632	}
633
634	return 0;
635}
636
637
638UdpEndpoint *
639UdpDomainSupport::_EndpointWithPort(uint16 port) const
640{
641	EndpointTable::Iterator it = fActiveEndpoints.GetIterator();
642
643	while (it.HasNext()) {
644		UdpEndpoint *endpoint = it.Next();
645		if (endpoint->LocalAddress().Port() == port)
646			return endpoint;
647	}
648
649	return NULL;
650}
651
652
653// #pragma mark -
654
655
656UdpEndpointManager::UdpEndpointManager()
657{
658	mutex_init(&fLock, "UDP endpoints");
659	fStatus = B_OK;
660}
661
662
663UdpEndpointManager::~UdpEndpointManager()
664{
665	mutex_destroy(&fLock);
666}
667
668
669status_t
670UdpEndpointManager::InitCheck() const
671{
672	return fStatus;
673}
674
675
676int
677UdpEndpointManager::DumpEndpoints(int argc, char *argv[])
678{
679	UdpDomainList::Iterator it = sUdpEndpointManager->fDomains.GetIterator();
680
681	while (it.HasNext())
682		it.Next()->DumpEndpoints();
683
684	return 0;
685}
686
687
688// #pragma mark - inbound
689
690
691status_t
692UdpEndpointManager::ReceiveData(net_buffer *buffer)
693{
694	TRACE_EPM("ReceiveData(%p [%" B_PRIu32 " bytes])", buffer, buffer->size);
695
696	UdpDomainSupport* domainSupport = _GetDomainSupport(buffer);
697	if (domainSupport == NULL) {
698		// we don't instantiate domain supports in the receiving path, as
699		// we are only interested in delivering data to existing sockets.
700		return B_ERROR;
701	}
702
703	status_t status = Deframe(buffer);
704	if (status != B_OK)
705		return status;
706
707	status = domainSupport->DemuxIncomingBuffer(buffer);
708	if (status != B_OK) {
709		TRACE_EPM("  ReceiveData(): no endpoint.");
710		// Send port unreachable error
711		domainSupport->Domain()->module->error_reply(NULL, buffer,
712			B_NET_ERROR_UNREACH_PORT, NULL);
713		return B_ERROR;
714	}
715
716	gBufferModule->free(buffer);
717	return B_OK;
718}
719
720
721status_t
722UdpEndpointManager::ReceiveError(status_t error, net_buffer* buffer)
723{
724	TRACE_EPM("ReceiveError(code %" B_PRId32 " %p [%" B_PRIu32 " bytes])",
725		error, buffer, buffer->size);
726
727	// We only really need the port information
728	if (buffer->size < 4)
729		return B_BAD_VALUE;
730
731	UdpDomainSupport* domainSupport = _GetDomainSupport(buffer);
732	if (domainSupport == NULL) {
733		// we don't instantiate domain supports in the receiving path, as
734		// we are only interested in delivering data to existing sockets.
735		return B_ERROR;
736	}
737
738	// Deframe the buffer manually, as we usually only get 8 bytes from the
739	// original packet
740	udp_header header;
741	if (gBufferModule->read(buffer, 0, &header,
742			std::min((size_t)buffer->size, sizeof(udp_header))) != B_OK)
743		return B_BAD_VALUE;
744
745	net_domain* domain = buffer->interface_address->domain;
746	net_address_module_info* addressModule = domain->address_module;
747
748	SocketAddress source(addressModule, buffer->source);
749	SocketAddress destination(addressModule, buffer->destination);
750
751	source.SetPort(header.source_port);
752	destination.SetPort(header.destination_port);
753
754	return domainSupport->DeliverError(error, buffer);
755}
756
757
758status_t
759UdpEndpointManager::Deframe(net_buffer* buffer)
760{
761	TRACE_EPM("Deframe(%p [%ld bytes])", buffer, buffer->size);
762
763	NetBufferHeaderReader<udp_header> bufferHeader(buffer);
764	if (bufferHeader.Status() != B_OK)
765		return bufferHeader.Status();
766
767	udp_header& header = bufferHeader.Data();
768
769	net_domain* domain = _GetDomain(buffer);
770	if (domain == NULL) {
771		TRACE_EPM("  Deframe(): UDP packed dropped as there was no domain "
772			"specified (interface address %p).", buffer->interface_address);
773		return B_BAD_VALUE;
774	}
775	net_address_module_info* addressModule = domain->address_module;
776
777	SocketAddress source(addressModule, buffer->source);
778	SocketAddress destination(addressModule, buffer->destination);
779
780	source.SetPort(header.source_port);
781	destination.SetPort(header.destination_port);
782
783	TRACE_EPM("  Deframe(): data from %s to %s", source.AsString(true).Data(),
784		destination.AsString(true).Data());
785
786	uint16 udpLength = ntohs(header.udp_length);
787	if (udpLength > buffer->size) {
788		TRACE_EPM("  Deframe(): buffer is too short, expected %hu.",
789			udpLength);
790		return B_MISMATCHED_VALUES;
791	}
792
793	if (buffer->size > udpLength)
794		gBufferModule->trim(buffer, udpLength);
795
796	if (header.udp_checksum != 0) {
797		// check UDP-checksum (simulating a so-called "pseudo-header"):
798		uint16 sum = Checksum::PseudoHeader(addressModule, gBufferModule,
799			buffer, IPPROTO_UDP);
800		if (sum != 0) {
801			TRACE_EPM("  Deframe(): bad checksum 0x%hx.", sum);
802			return B_BAD_VALUE;
803		}
804	}
805
806	bufferHeader.Remove();
807		// remove UDP-header from buffer before passing it on
808
809	return B_OK;
810}
811
812
813UdpDomainSupport *
814UdpEndpointManager::OpenEndpoint(UdpEndpoint *endpoint)
815{
816	MutexLocker _(fLock);
817
818	UdpDomainSupport* domain = _GetDomainSupport(endpoint->Domain(), true);
819	if (domain)
820		domain->Ref();
821	return domain;
822}
823
824
825status_t
826UdpEndpointManager::FreeEndpoint(UdpDomainSupport *domain)
827{
828	MutexLocker _(fLock);
829
830	if (domain->Put()) {
831		fDomains.Remove(domain);
832		delete domain;
833	}
834
835	return B_OK;
836}
837
838
839// #pragma mark -
840
841
842inline net_domain*
843UdpEndpointManager::_GetDomain(net_buffer* buffer)
844{
845	if (buffer->interface_address != NULL)
846		return buffer->interface_address->domain;
847
848	return gStackModule->get_domain(buffer->destination->sa_family);
849}
850
851
852UdpDomainSupport*
853UdpEndpointManager::_GetDomainSupport(net_domain* domain, bool create)
854{
855	ASSERT_LOCKED_MUTEX(&fLock);
856
857	if (domain == NULL)
858		return NULL;
859
860	// TODO convert this into a Hashtable or install per-domain
861	//      receiver handlers that forward the requests to the
862	//      appropriate DemuxIncomingBuffer(). For instance, while
863	//      being constructed UdpDomainSupport could call
864	//      register_domain_receiving_protocol() with the right
865	//      family.
866	UdpDomainList::Iterator iterator = fDomains.GetIterator();
867	while (UdpDomainSupport* domainSupport = iterator.Next()) {
868		if (domainSupport->Domain() == domain)
869			return domainSupport;
870	}
871
872	if (!create)
873		return NULL;
874
875	UdpDomainSupport* domainSupport
876		= new (std::nothrow) UdpDomainSupport(domain);
877	if (domainSupport == NULL || domainSupport->Init() < B_OK) {
878		delete domainSupport;
879		return NULL;
880	}
881
882	fDomains.Add(domainSupport);
883	return domainSupport;
884}
885
886
887/*!	Retrieves the UdpDomainSupport object responsible for this buffer, if the
888	domain can be determined. This is only successful if the domain support is
889	already existing, ie. there must already be an endpoint for the domain.
890*/
891UdpDomainSupport*
892UdpEndpointManager::_GetDomainSupport(net_buffer* buffer)
893{
894	MutexLocker _(fLock);
895
896	return _GetDomainSupport(_GetDomain(buffer), false);
897		// TODO: we don't want to hold to the manager's lock during the
898		// whole RX path, we may not hold an endpoint's lock with the
899		// manager lock held.
900		// But we should increase the domain's refcount here.
901}
902
903
904// #pragma mark -
905
906
907UdpEndpoint::UdpEndpoint(net_socket *socket)
908	:
909	DatagramSocket<>("udp endpoint", socket),
910	fActive(false)
911{
912}
913
914
915// #pragma mark - activation
916
917
918status_t
919UdpEndpoint::Bind(const sockaddr *address)
920{
921	TRACE_EP("Bind(%s)", AddressString(Domain(), address, true).Data());
922	return fManager->BindEndpoint(this, address);
923}
924
925
926status_t
927UdpEndpoint::Unbind(sockaddr *address)
928{
929	TRACE_EP("Unbind()");
930	return fManager->UnbindEndpoint(this);
931}
932
933
934status_t
935UdpEndpoint::Connect(const sockaddr *address)
936{
937	TRACE_EP("Connect(%s)", AddressString(Domain(), address, true).Data());
938	return fManager->ConnectEndpoint(this, address);
939}
940
941
942status_t
943UdpEndpoint::Open()
944{
945	TRACE_EP("Open()");
946
947	AutoLocker _(fLock);
948
949	status_t status = ProtocolSocket::Open();
950	if (status < B_OK)
951		return status;
952
953	fManager = sUdpEndpointManager->OpenEndpoint(this);
954	if (fManager == NULL)
955		return EAFNOSUPPORT;
956
957	return B_OK;
958}
959
960
961status_t
962UdpEndpoint::Close()
963{
964	TRACE_EP("Close()");
965	return B_OK;
966}
967
968
969status_t
970UdpEndpoint::Free()
971{
972	TRACE_EP("Free()");
973	fManager->UnbindEndpoint(this);
974	return sUdpEndpointManager->FreeEndpoint(fManager);
975}
976
977
978// #pragma mark - outbound
979
980
981status_t
982UdpEndpoint::SendRoutedData(net_buffer *buffer, net_route *route)
983{
984	TRACE_EP("SendRoutedData(%p [%lu bytes], %p)", buffer, buffer->size, route);
985
986	if (buffer->size > (0xffff - sizeof(udp_header)))
987		return EMSGSIZE;
988
989	buffer->protocol = IPPROTO_UDP;
990
991	// add and fill UDP-specific header:
992	NetBufferPrepend<udp_header> header(buffer);
993	if (header.Status() < B_OK)
994		return header.Status();
995
996	header->source_port = AddressModule()->get_port(buffer->source);
997	header->destination_port = AddressModule()->get_port(buffer->destination);
998	header->udp_length = htons(buffer->size);
999		// the udp-header is already included in the buffer-size
1000	header->udp_checksum = 0;
1001
1002	header.Sync();
1003
1004	uint16 calculatedChecksum = Checksum::PseudoHeader(AddressModule(),
1005		gBufferModule, buffer, IPPROTO_UDP);
1006	if (calculatedChecksum == 0)
1007		calculatedChecksum = 0xffff;
1008
1009	*UDPChecksumField(buffer) = calculatedChecksum;
1010
1011	return next->module->send_routed_data(next, route, buffer);
1012}
1013
1014
1015status_t
1016UdpEndpoint::SendData(net_buffer *buffer)
1017{
1018	TRACE_EP("SendData(%p [%lu bytes])", buffer, buffer->size);
1019
1020	return gDatalinkModule->send_data(this, NULL, buffer);
1021}
1022
1023
1024// #pragma mark - inbound
1025
1026
1027ssize_t
1028UdpEndpoint::BytesAvailable()
1029{
1030	size_t bytes = AvailableData();
1031	TRACE_EP("BytesAvailable(): %lu", bytes);
1032	return bytes;
1033}
1034
1035
1036status_t
1037UdpEndpoint::FetchData(size_t numBytes, uint32 flags, net_buffer **_buffer)
1038{
1039	TRACE_EP("FetchData(%ld, 0x%lx)", numBytes, flags);
1040
1041	status_t status = Dequeue(flags, _buffer);
1042	TRACE_EP("  FetchData(): returned from fifo status: %s", strerror(status));
1043	if (status != B_OK)
1044		return status;
1045
1046	TRACE_EP("  FetchData(): returns buffer with %ld bytes", (*_buffer)->size);
1047	return B_OK;
1048}
1049
1050
1051status_t
1052UdpEndpoint::StoreData(net_buffer *buffer)
1053{
1054	TRACE_EP("StoreData(%p [%ld bytes])", buffer, buffer->size);
1055
1056	return EnqueueClone(buffer);
1057}
1058
1059
1060status_t
1061UdpEndpoint::DeliverData(net_buffer *_buffer)
1062{
1063	TRACE_EP("DeliverData(%p [%ld bytes])", _buffer, _buffer->size);
1064
1065	net_buffer *buffer = gBufferModule->clone(_buffer, false);
1066	if (buffer == NULL)
1067		return B_NO_MEMORY;
1068
1069	status_t status = sUdpEndpointManager->Deframe(buffer);
1070	if (status < B_OK) {
1071		gBufferModule->free(buffer);
1072		return status;
1073	}
1074
1075	return Enqueue(buffer);
1076}
1077
1078
1079void
1080UdpEndpoint::Dump() const
1081{
1082	char local[64];
1083	LocalAddress().AsString(local, sizeof(local), true);
1084	char peer[64];
1085	PeerAddress().AsString(peer, sizeof(peer), true);
1086
1087	kprintf("%p %20s %20s %8lu\n", this, local, peer, fCurrentBytes);
1088}
1089
1090
1091// #pragma mark - protocol interface
1092
1093
1094net_protocol *
1095udp_init_protocol(net_socket *socket)
1096{
1097	socket->protocol = IPPROTO_UDP;
1098
1099	UdpEndpoint *endpoint = new (std::nothrow) UdpEndpoint(socket);
1100	if (endpoint == NULL || endpoint->InitCheck() < B_OK) {
1101		delete endpoint;
1102		return NULL;
1103	}
1104
1105	return endpoint;
1106}
1107
1108
1109status_t
1110udp_uninit_protocol(net_protocol *protocol)
1111{
1112	delete (UdpEndpoint *)protocol;
1113	return B_OK;
1114}
1115
1116
1117status_t
1118udp_open(net_protocol *protocol)
1119{
1120	return ((UdpEndpoint *)protocol)->Open();
1121}
1122
1123
1124status_t
1125udp_close(net_protocol *protocol)
1126{
1127	return ((UdpEndpoint *)protocol)->Close();
1128}
1129
1130
1131status_t
1132udp_free(net_protocol *protocol)
1133{
1134	return ((UdpEndpoint *)protocol)->Free();
1135}
1136
1137
1138status_t
1139udp_connect(net_protocol *protocol, const struct sockaddr *address)
1140{
1141	return ((UdpEndpoint *)protocol)->Connect(address);
1142}
1143
1144
1145status_t
1146udp_accept(net_protocol *protocol, struct net_socket **_acceptedSocket)
1147{
1148	return B_NOT_SUPPORTED;
1149}
1150
1151
1152status_t
1153udp_control(net_protocol *protocol, int level, int option, void *value,
1154	size_t *_length)
1155{
1156	return protocol->next->module->control(protocol->next, level, option,
1157		value, _length);
1158}
1159
1160
1161status_t
1162udp_getsockopt(net_protocol *protocol, int level, int option, void *value,
1163	int *length)
1164{
1165	return protocol->next->module->getsockopt(protocol->next, level, option,
1166		value, length);
1167}
1168
1169
1170status_t
1171udp_setsockopt(net_protocol *protocol, int level, int option,
1172	const void *value, int length)
1173{
1174	return protocol->next->module->setsockopt(protocol->next, level, option,
1175		value, length);
1176}
1177
1178
1179status_t
1180udp_bind(net_protocol *protocol, const struct sockaddr *address)
1181{
1182	return ((UdpEndpoint *)protocol)->Bind(address);
1183}
1184
1185
1186status_t
1187udp_unbind(net_protocol *protocol, struct sockaddr *address)
1188{
1189	return ((UdpEndpoint *)protocol)->Unbind(address);
1190}
1191
1192
1193status_t
1194udp_listen(net_protocol *protocol, int count)
1195{
1196	return B_NOT_SUPPORTED;
1197}
1198
1199
1200status_t
1201udp_shutdown(net_protocol *protocol, int direction)
1202{
1203	return B_NOT_SUPPORTED;
1204}
1205
1206
1207status_t
1208udp_send_routed_data(net_protocol *protocol, struct net_route *route,
1209	net_buffer *buffer)
1210{
1211	return ((UdpEndpoint *)protocol)->SendRoutedData(buffer, route);
1212}
1213
1214
1215status_t
1216udp_send_data(net_protocol *protocol, net_buffer *buffer)
1217{
1218	return ((UdpEndpoint *)protocol)->SendData(buffer);
1219}
1220
1221
1222ssize_t
1223udp_send_avail(net_protocol *protocol)
1224{
1225	return protocol->socket->send.buffer_size;
1226}
1227
1228
1229status_t
1230udp_read_data(net_protocol *protocol, size_t numBytes, uint32 flags,
1231	net_buffer **_buffer)
1232{
1233	return ((UdpEndpoint *)protocol)->FetchData(numBytes, flags, _buffer);
1234}
1235
1236
1237ssize_t
1238udp_read_avail(net_protocol *protocol)
1239{
1240	return ((UdpEndpoint *)protocol)->BytesAvailable();
1241}
1242
1243
1244struct net_domain *
1245udp_get_domain(net_protocol *protocol)
1246{
1247	return protocol->next->module->get_domain(protocol->next);
1248}
1249
1250
1251size_t
1252udp_get_mtu(net_protocol *protocol, const struct sockaddr *address)
1253{
1254	return protocol->next->module->get_mtu(protocol->next, address);
1255}
1256
1257
1258status_t
1259udp_receive_data(net_buffer *buffer)
1260{
1261	return sUdpEndpointManager->ReceiveData(buffer);
1262}
1263
1264
1265status_t
1266udp_deliver_data(net_protocol *protocol, net_buffer *buffer)
1267{
1268	return ((UdpEndpoint *)protocol)->DeliverData(buffer);
1269}
1270
1271
1272status_t
1273udp_error_received(net_error error, net_buffer* buffer)
1274{
1275	status_t notifyError = B_OK;
1276
1277	switch (error) {
1278		case B_NET_ERROR_UNREACH_NET:
1279			notifyError = ENETUNREACH;
1280			break;
1281		case B_NET_ERROR_UNREACH_HOST:
1282		case B_NET_ERROR_TRANSIT_TIME_EXCEEDED:
1283			notifyError = EHOSTUNREACH;
1284			break;
1285		case B_NET_ERROR_UNREACH_PROTOCOL:
1286		case B_NET_ERROR_UNREACH_PORT:
1287			notifyError = ECONNREFUSED;
1288			break;
1289		case B_NET_ERROR_MESSAGE_SIZE:
1290			notifyError = EMSGSIZE;
1291			break;
1292		case B_NET_ERROR_PARAMETER_PROBLEM:
1293			notifyError = ENOPROTOOPT;
1294			break;
1295
1296		case B_NET_ERROR_QUENCH:
1297		default:
1298			// ignore them
1299			gBufferModule->free(buffer);
1300			return B_OK;
1301	}
1302
1303	ASSERT(notifyError != B_OK);
1304
1305	return sUdpEndpointManager->ReceiveError(notifyError, buffer);
1306}
1307
1308
1309status_t
1310udp_error_reply(net_protocol *protocol, net_buffer *cause, net_error error,
1311	net_error_data *errorData)
1312{
1313	return B_ERROR;
1314}
1315
1316
1317ssize_t
1318udp_process_ancillary_data_no_container(net_protocol *protocol,
1319	net_buffer* buffer, void *data, size_t dataSize)
1320{
1321	return protocol->next->module->process_ancillary_data_no_container(
1322		protocol, buffer, data, dataSize);
1323}
1324
1325
1326//	#pragma mark - module interface
1327
1328
1329static status_t
1330init_udp()
1331{
1332	status_t status;
1333	TRACE_EPM("init_udp()");
1334
1335	sUdpEndpointManager = new (std::nothrow) UdpEndpointManager;
1336	if (sUdpEndpointManager == NULL)
1337		return B_NO_MEMORY;
1338
1339	status = sUdpEndpointManager->InitCheck();
1340	if (status != B_OK)
1341		goto err1;
1342
1343	status = gStackModule->register_domain_protocols(AF_INET, SOCK_DGRAM,
1344		IPPROTO_IP,
1345		"network/protocols/udp/v1",
1346		"network/protocols/ipv4/v1",
1347		NULL);
1348	if (status < B_OK)
1349		goto err1;
1350	status = gStackModule->register_domain_protocols(AF_INET6, SOCK_DGRAM,
1351		IPPROTO_IP,
1352		"network/protocols/udp/v1",
1353		"network/protocols/ipv6/v1",
1354		NULL);
1355	if (status < B_OK)
1356		goto err1;
1357
1358	status = gStackModule->register_domain_protocols(AF_INET, SOCK_DGRAM,
1359		IPPROTO_UDP,
1360		"network/protocols/udp/v1",
1361		"network/protocols/ipv4/v1",
1362		NULL);
1363	if (status < B_OK)
1364		goto err1;
1365	status = gStackModule->register_domain_protocols(AF_INET6, SOCK_DGRAM,
1366		IPPROTO_UDP,
1367		"network/protocols/udp/v1",
1368		"network/protocols/ipv6/v1",
1369		NULL);
1370	if (status < B_OK)
1371		goto err1;
1372
1373	status = gStackModule->register_domain_receiving_protocol(AF_INET,
1374		IPPROTO_UDP, "network/protocols/udp/v1");
1375	if (status < B_OK)
1376		goto err1;
1377	status = gStackModule->register_domain_receiving_protocol(AF_INET6,
1378		IPPROTO_UDP, "network/protocols/udp/v1");
1379	if (status < B_OK)
1380		goto err1;
1381
1382	add_debugger_command("udp_endpoints", UdpEndpointManager::DumpEndpoints,
1383		"lists all open UDP endpoints");
1384
1385	return B_OK;
1386
1387err1:
1388	// TODO: shouldn't unregister the protocols here?
1389	delete sUdpEndpointManager;
1390
1391	TRACE_EPM("init_udp() fails with %lx (%s)", status, strerror(status));
1392	return status;
1393}
1394
1395
1396static status_t
1397uninit_udp()
1398{
1399	TRACE_EPM("uninit_udp()");
1400	remove_debugger_command("udp_endpoints",
1401		UdpEndpointManager::DumpEndpoints);
1402	delete sUdpEndpointManager;
1403	return B_OK;
1404}
1405
1406
1407static status_t
1408udp_std_ops(int32 op, ...)
1409{
1410	switch (op) {
1411		case B_MODULE_INIT:
1412			return init_udp();
1413
1414		case B_MODULE_UNINIT:
1415			return uninit_udp();
1416
1417		default:
1418			return B_ERROR;
1419	}
1420}
1421
1422
// Protocol module descriptor for UDP. This is an aggregate initializer, so
// the entries must stay in the field order of net_protocol_module_info.
// NULL entries are hooks this module does not provide.
net_protocol_module_info sUDPModule = {
	{
		// module_info: name, flags, standard operations hook
		"network/protocols/udp/v1",
		0,
		udp_std_ops
	},
	NET_PROTOCOL_ATOMIC_MESSAGES,	// protocol flags

	udp_init_protocol,
	udp_uninit_protocol,
	udp_open,
	udp_close,
	udp_free,
	udp_connect,
	udp_accept,
	udp_control,
	udp_getsockopt,
	udp_setsockopt,
	udp_bind,
	udp_unbind,
	udp_listen,
	udp_shutdown,
	udp_send_data,
	udp_send_routed_data,
	udp_send_avail,
	udp_read_data,
	udp_read_avail,
	udp_get_domain,
	udp_get_mtu,
	udp_receive_data,
	udp_deliver_data,
	udp_error_received,
	udp_error_reply,
	NULL,		// add_ancillary_data()
	NULL,		// process_ancillary_data()
	udp_process_ancillary_data_no_container,
	NULL,		// send_data_no_buffer()
	NULL		// read_data_no_buffer()
};
1462
// Modules this add-on depends on, each paired with the global pointer that
// receives it. The pointers are presumably filled in by the module loader
// before init_udp() runs (verify against the kernel module API). The empty
// entry terminates the list.
module_dependency module_dependencies[] = {
	{NET_STACK_MODULE_NAME, (module_info **)&gStackModule},
	{NET_BUFFER_MODULE_NAME, (module_info **)&gBufferModule},
	{NET_DATALINK_MODULE_NAME, (module_info **)&gDatalinkModule},
	{NET_SOCKET_MODULE_NAME, (module_info **)&gSocketModule},
	{}
};
1470
// NULL-terminated list of the modules exported by this add-on (only UDP).
module_info *modules[] = {
	(module_info *)&sUDPModule,
	NULL
};
1475