1/*
2 * Copyright 2006-2021, Haiku, Inc. All Rights Reserved.
3 * Distributed under the terms of the MIT License.
4 *
5 * Authors:
6 *		Oliver Tappe, zooey@hirschkaefer.de
7 *		Hugo Santos, hugosantos@gmail.com
8 */
9
10
11#include <net_buffer.h>
12#include <net_datalink.h>
13#include <net_protocol.h>
14#include <net_stack.h>
15
16#include <lock.h>
17#include <util/AutoLock.h>
18#include <util/DoublyLinkedList.h>
19#include <util/OpenHashTable.h>
20
21#include <AutoDeleter.h>
22#include <KernelExport.h>
23
24#include <NetBufferUtilities.h>
25#include <NetUtilities.h>
26#include <ProtocolUtilities.h>
27
28#include <algorithm>
29#include <netinet/in.h>
30#include <netinet/ip.h>
31#include <new>
32#include <stdlib.h>
33#include <string.h>
34#include <utility>
35
36
// NOTE the locking protocol dictates that we must hold UdpDomainSupport's
//      lock before acquiring a child UdpEndpoint's lock. This ordering is
//      imposed by the receive path, which needs unrestricted access to the
//      endpoint hash while holding the UdpDomainSupport's lock.
41
42
// Debug tracing. Uncomment the define below to route the TRACE_* macros to
// dprintf(); otherwise they expand to nothing (or to empty statements).
//  - TRACE_EP uses `this`, so it is only usable inside member functions.
//  - TRACE_EPM prints without an object pointer.
//  - TRACE_DOMAIN prints Domain()->family, so it needs a Domain() accessor in
//    scope (as in UdpDomainSupport).
//  - TRACE_BLOCK expands to dump_block, which is not defined in this file
//    (NOTE(review): presumably provided elsewhere when tracing is enabled).
//#define TRACE_UDP
#ifdef TRACE_UDP
#	define TRACE_BLOCK(x) dump_block x
// do not remove the space before ', ##args' if you want this
// to compile with gcc 2.95
#	define TRACE_EP(format, args...)	dprintf("UDP [%" B_PRIu64 ",%" \
		B_PRIu32 "] %p " format "\n", system_time(), \
		thread_get_current_thread_id(), this , ##args)
#	define TRACE_EPM(format, args...)	dprintf("UDP [%" B_PRIu64 ",%" \
		B_PRIu32 "] " format "\n", system_time() , \
		thread_get_current_thread_id() , ##args)
#	define TRACE_DOMAIN(format, args...)	dprintf("UDP [%" B_PRIu64 ",%" \
		B_PRIu32 "] (%d) " format "\n", system_time(), \
		thread_get_current_thread_id(), Domain()->family , ##args)
#else
#	define TRACE_BLOCK(x)
#	define TRACE_EP(args...)	do { } while (0)
#	define TRACE_EPM(args...)	do { } while (0)
#	define TRACE_DOMAIN(args...)	do { } while (0)
#endif
63
64
// The UDP header as it appears on the wire. _PACKED rules out padding.
// All fields are in network byte order (the code converts udp_length with
// ntohs()/htons() and copies ports verbatim from socket addresses).
struct udp_header {
	uint16 source_port;
	uint16 destination_port;
	uint16 udp_length;
		// length of header plus payload, in bytes
	uint16 udp_checksum;
		// 0 on receipt means "no checksum was computed"
} _PACKED;
71
72
// Accessor for the checksum field of a UDP header at the start of a
// net_buffer; used to patch in the checksum after it has been computed.
typedef NetBufferField<uint16, offsetof(udp_header, udp_checksum)>
	UDPChecksumField;

class UdpDomainSupport;
77
// The per-socket UDP protocol object. It is the socket's entry in the
// net_protocol chain and, through DatagramSocket<>, also its queue of
// received datagrams. Binding and connecting are delegated to the
// UdpDomainSupport of the socket's domain (fManager).
class UdpEndpoint : public net_protocol, public DatagramSocket<> {
public:
								UdpEndpoint(net_socket* socket);

			// activation; forwarded to fManager
			status_t			Bind(const sockaddr* newAddr);
			status_t			Unbind(sockaddr* newAddr);
			status_t			Connect(const sockaddr* newAddr);

			status_t			Open();
			status_t			Close();
			status_t			Free();

			// outbound
			status_t			SendRoutedData(net_buffer* buffer,
									net_route* route);
			status_t			SendData(net_buffer* buffer);

			// inbound
			ssize_t				BytesAvailable();
			status_t			FetchData(size_t numBytes, uint32 flags,
									net_buffer** _buffer);

			status_t			StoreData(net_buffer* buffer);
			status_t			DeliverData(net_buffer* buffer);

			// only the domain support will change/check the Active flag so
			// we don't really need to protect it with the socket lock.
			// (UdpDomainSupport accesses it while holding its own fLock.)
			bool				IsActive() const { return fActive; }
			void				SetActive(bool newValue) { fActive = newValue; }

			// link to the next endpoint in the same bucket of the
			// UdpDomainSupport endpoint hash
			UdpEndpoint*&		HashTableLink() { return fLink; }

			void				Dump() const;

private:
			UdpDomainSupport*	fManager;
									// domain support obtained in Open()
			bool				fActive;
									// an active UdpEndpoint is part of the
									// endpoint hash (and it is bound and
									// optionally connected)

			UdpEndpoint*		fLink;
									// hash bucket chain, see HashTableLink()
};
119
120
121class UdpDomainSupport;
122
123struct UdpHashDefinition {
124	typedef net_address_module_info ParentType;
125	typedef std::pair<const sockaddr *, const sockaddr *> KeyType;
126	typedef UdpEndpoint ValueType;
127
128	UdpHashDefinition(net_address_module_info *_module)
129		: module(_module) {}
130	UdpHashDefinition(const UdpHashDefinition& definition)
131		: module(definition.module) {}
132
133	size_t HashKey(const KeyType &key) const
134	{
135		return _Mix(module->hash_address_pair(key.first, key.second));
136	}
137
138	size_t Hash(UdpEndpoint *endpoint) const
139	{
140		return _Mix(endpoint->LocalAddress().HashPair(
141			*endpoint->PeerAddress()));
142	}
143
144	static size_t _Mix(size_t hash)
145	{
146		// move the bits into the relevant range (as defined by kNumHashBuckets)
147		return (hash & 0x000007FF) ^ (hash & 0x003FF800) >> 11
148			^ (hash & 0xFFC00000UL) >> 22;
149	}
150
151	bool Compare(const KeyType &key, UdpEndpoint *endpoint) const
152	{
153		return endpoint->LocalAddress().EqualTo(key.first, true)
154			&& endpoint->PeerAddress().EqualTo(key.second, true);
155	}
156
157	UdpEndpoint *&GetLink(UdpEndpoint *endpoint) const
158	{
159		return endpoint->HashTableLink();
160	}
161
162	net_address_module_info *module;
163};
164
165
// Per-domain (address family) UDP state: the hash of active (bound)
// endpoints, the ephemeral port allocator and the demultiplexing of
// incoming datagrams. Instances are reference counted by the endpoints
// using them (Ref()/Put() are called with UdpEndpointManager's lock held).
class UdpDomainSupport : public DoublyLinkedListLinkImpl<UdpDomainSupport> {
public:
	UdpDomainSupport(net_domain *domain);
	~UdpDomainSupport();

	status_t Init();

	net_domain *Domain() const { return fDomain; }

	// reference counting; Put() returns true when the last reference is gone
	void Ref() { fEndpointCount++; }
	bool Put() { fEndpointCount--; return fEndpointCount == 0; }

	status_t DemuxIncomingBuffer(net_buffer* buffer);
	status_t DeliverError(status_t error, net_buffer* buffer);

	status_t BindEndpoint(UdpEndpoint *endpoint, const sockaddr *address);
	status_t ConnectEndpoint(UdpEndpoint *endpoint, const sockaddr *address);
	status_t UnbindEndpoint(UdpEndpoint *endpoint);

	void DumpEndpoints() const;

private:
	// the following helpers expect fLock to be held
	status_t _BindEndpoint(UdpEndpoint *endpoint, const sockaddr *address);
	status_t _Bind(UdpEndpoint *endpoint, const sockaddr *address);
	status_t _BindToEphemeral(UdpEndpoint *endpoint, const sockaddr *address);
	status_t _FinishBind(UdpEndpoint *endpoint, const sockaddr *address);

	UdpEndpoint *_FindActiveEndpoint(const sockaddr *ourAddress,
		const sockaddr *peerAddress, uint32 index = 0);
	status_t _DemuxBroadcast(net_buffer *buffer);
	status_t _DemuxUnicast(net_buffer *buffer);

	uint16 _GetNextEphemeral();
	UdpEndpoint *_EndpointWithPort(uint16 port) const;

	net_address_module_info *AddressModule() const
		{ return fDomain->address_module; }

	// NOTE(review): the 'false' template argument presumably disables
	// automatic resizing, keeping kNumHashBuckets fixed -- confirm in
	// OpenHashTable.h.
	typedef BOpenHashTable<UdpHashDefinition, false> EndpointTable;

	mutex			fLock;
	net_domain		*fDomain;
	uint16			fLastUsedEphemeral;
	EndpointTable	fActiveEndpoints;
	uint32			fEndpointCount;

	// ephemeral port range searched by _GetNextEphemeral() (host order)
	static const uint16		kFirst = 49152;
	static const uint16		kLast = 65535;
	static const uint32		kNumHashBuckets = 0x800;
							// if you change this, adjust the shifting in
							// UdpHashDefinition::_Mix() accordingly!
};
218
219
typedef DoublyLinkedList<UdpDomainSupport> UdpDomainList;


// Module-wide entry point of the UDP protocol: dispatches incoming data and
// errors to the UdpDomainSupport of the matching domain, and hands out
// (reference counted) domain supports to endpoints. fLock protects fDomains
// and the reference counts of its entries.
class UdpEndpointManager {
public:
								UdpEndpointManager();
								~UdpEndpointManager();

			status_t			InitCheck() const;

			status_t			ReceiveData(net_buffer* buffer);
			status_t			ReceiveError(status_t error,
									net_buffer* buffer);
			status_t			Deframe(net_buffer* buffer);

			// acquire/release a reference to the endpoint's domain support
			UdpDomainSupport*	OpenEndpoint(UdpEndpoint* endpoint);
			status_t			FreeEndpoint(UdpDomainSupport* domain);

	static	int					DumpEndpoints(int argc, char *argv[]);

private:
	inline	net_domain*			_GetDomain(net_buffer* buffer);
			UdpDomainSupport*	_GetDomainSupport(net_domain* domain,
									bool create);
			UdpDomainSupport*	_GetDomainSupport(net_buffer* buffer);

			mutex				fLock;
			status_t			fStatus;
			UdpDomainList		fDomains;
};
250
251
// The single manager instance used by all endpoints and the receive path.
// NOTE(review): it is created outside this chunk, presumably during module
// initialization.
static UdpEndpointManager *sUdpEndpointManager;

// Networking stack modules used by this protocol. They are not assigned in
// this chunk; presumably they are filled in when the module is loaded.
net_buffer_module_info *gBufferModule;
net_datalink_module_info *gDatalinkModule;
net_stack_module_info *gStackModule;
net_socket_module_info *gSocketModule;
258
259
260// #pragma mark -
261
262
263UdpDomainSupport::UdpDomainSupport(net_domain *domain)
264	:
265	fDomain(domain),
266	fActiveEndpoints(domain->address_module),
267	fEndpointCount(0)
268{
269	mutex_init(&fLock, "udp domain");
270
271	fLastUsedEphemeral = kFirst + rand() % (kLast - kFirst);
272}
273
274
275UdpDomainSupport::~UdpDomainSupport()
276{
277	mutex_destroy(&fLock);
278}
279
280
281status_t
282UdpDomainSupport::Init()
283{
284	return fActiveEndpoints.Init(kNumHashBuckets);
285}
286
287
288status_t
289UdpDomainSupport::DemuxIncomingBuffer(net_buffer *buffer)
290{
291	// NOTE: multicast is delivered directly to the endpoint
292	MutexLocker _(fLock);
293
294	if ((buffer->flags & MSG_BCAST) != 0)
295		return _DemuxBroadcast(buffer);
296	if ((buffer->flags & MSG_MCAST) != 0)
297		return B_ERROR;
298
299	return _DemuxUnicast(buffer);
300}
301
302
303status_t
304UdpDomainSupport::DeliverError(status_t error, net_buffer* buffer)
305{
306	if ((buffer->flags & (MSG_BCAST | MSG_MCAST)) != 0)
307		return B_ERROR;
308
309	MutexLocker _(fLock);
310
311	// Forward the error to the socket
312	UdpEndpoint* endpoint = _FindActiveEndpoint(buffer->source,
313		buffer->destination);
314	if (endpoint != NULL) {
315		gSocketModule->notify(endpoint->Socket(), B_SELECT_ERROR, error);
316		endpoint->NotifyOne();
317	}
318
319	gBufferModule->free(buffer);
320	return B_OK;
321}
322
323
324status_t
325UdpDomainSupport::BindEndpoint(UdpEndpoint *endpoint,
326	const sockaddr *address)
327{
328	if (!AddressModule()->is_same_family(address))
329		return EAFNOSUPPORT;
330
331	MutexLocker _(fLock);
332
333	if (endpoint->IsActive())
334		return EINVAL;
335
336	return _BindEndpoint(endpoint, address);
337}
338
339
340status_t
341UdpDomainSupport::ConnectEndpoint(UdpEndpoint *endpoint,
342	const sockaddr *address)
343{
344	MutexLocker _(fLock);
345
346	if (endpoint->IsActive()) {
347		fActiveEndpoints.Remove(endpoint);
348		endpoint->SetActive(false);
349	}
350
351	if (address->sa_family == AF_UNSPEC) {
352		// [Stevens-UNP1, p226]: specifying AF_UNSPEC requests a "disconnect",
353		// so we reset the peer address:
354		endpoint->PeerAddress().SetToEmpty();
355	} else {
356		if (!AddressModule()->is_same_family(address))
357			return EAFNOSUPPORT;
358
359		// consider destination address INADDR_ANY as INADDR_LOOPBACK
360		sockaddr_storage _address;
361		if (AddressModule()->is_empty_address(address, false)) {
362			AddressModule()->get_loopback_address((sockaddr *)&_address);
363			// for IPv4 and IPv6 the port is at the same offset
364			((sockaddr_in&)_address).sin_port
365				= ((sockaddr_in *)address)->sin_port;
366			address = (sockaddr *)&_address;
367		}
368
369		status_t status = endpoint->PeerAddress().SetTo(address);
370		if (status < B_OK)
371			return status;
372		struct net_route *routeToDestination
373			= gDatalinkModule->get_route(fDomain, address);
374		if (routeToDestination) {
375			// stay bound to current local port, if any.
376			uint16 port = endpoint->LocalAddress().Port();
377			status = endpoint->LocalAddress().SetTo(
378				routeToDestination->interface_address->local);
379			endpoint->LocalAddress().SetPort(port);
380			gDatalinkModule->put_route(fDomain, routeToDestination);
381			if (status < B_OK)
382				return status;
383		}
384	}
385
386	// we need to activate no matter whether or not we have just disconnected,
387	// as calling connect() always triggers an implicit bind():
388	status_t status = _BindEndpoint(endpoint, *endpoint->LocalAddress());
389	if (status == B_OK)
390		gSocketModule->set_connected(endpoint->Socket());
391	return status;
392}
393
394
395status_t
396UdpDomainSupport::UnbindEndpoint(UdpEndpoint *endpoint)
397{
398	MutexLocker _(fLock);
399
400	if (endpoint->IsActive())
401		fActiveEndpoints.Remove(endpoint);
402
403	endpoint->SetActive(false);
404
405	return B_OK;
406}
407
408
409void
410UdpDomainSupport::DumpEndpoints() const
411{
412	kprintf("-------- UDP Domain %p ---------\n", this);
413	kprintf("%10s %20s %20s %8s\n", "address", "local", "peer", "recv-q");
414
415	EndpointTable::Iterator it = fActiveEndpoints.GetIterator();
416
417	while (UdpEndpoint* endpoint = it.Next()) {
418		endpoint->Dump();
419	}
420}
421
422
423status_t
424UdpDomainSupport::_BindEndpoint(UdpEndpoint *endpoint,
425	const sockaddr *address)
426{
427	if (AddressModule()->get_port(address) == 0)
428		return _BindToEphemeral(endpoint, address);
429
430	return _Bind(endpoint, address);
431}
432
433
434status_t
435UdpDomainSupport::_Bind(UdpEndpoint *endpoint, const sockaddr *address)
436{
437	int socketOptions = endpoint->Socket()->options;
438
439	EndpointTable::Iterator it = fActiveEndpoints.GetIterator();
440
441	// Iterate over all active UDP-endpoints and check if the requested bind
442	// is allowed (see figure 22.24 in [Stevens - TCP2, p735]):
443	TRACE_DOMAIN("CheckBindRequest() for %s...", AddressString(fDomain,
444		address, true).Data());
445
446	while (it.HasNext()) {
447		UdpEndpoint *otherEndpoint = it.Next();
448
449		TRACE_DOMAIN("  ...checking endpoint %p (port=%u)...", otherEndpoint,
450			ntohs(otherEndpoint->LocalAddress().Port()));
451
452		if (otherEndpoint->LocalAddress().EqualPorts(address)) {
453			// port is already bound, SO_REUSEADDR or SO_REUSEPORT is required:
454			if ((otherEndpoint->Socket()->options
455					& (SO_REUSEADDR | SO_REUSEPORT)) == 0
456				|| (socketOptions & (SO_REUSEADDR | SO_REUSEPORT)) == 0)
457				return EADDRINUSE;
458
459			// if both addresses are the same, SO_REUSEPORT is required:
460			if (otherEndpoint->LocalAddress().EqualTo(address, false)
461				&& ((otherEndpoint->Socket()->options & SO_REUSEPORT) == 0
462					|| (socketOptions & SO_REUSEPORT) == 0))
463				return EADDRINUSE;
464		}
465	}
466
467	return _FinishBind(endpoint, address);
468}
469
470
471status_t
472UdpDomainSupport::_BindToEphemeral(UdpEndpoint *endpoint,
473	const sockaddr *address)
474{
475	SocketAddressStorage newAddress(AddressModule());
476	status_t status = newAddress.SetTo(address);
477	if (status < B_OK)
478		return status;
479
480	uint16 allocedPort = _GetNextEphemeral();
481	if (allocedPort == 0)
482		return ENOBUFS;
483
484	newAddress.SetPort(htons(allocedPort));
485
486	return _FinishBind(endpoint, *newAddress);
487}
488
489
490status_t
491UdpDomainSupport::_FinishBind(UdpEndpoint *endpoint, const sockaddr *address)
492{
493	status_t status = endpoint->next->module->bind(endpoint->next, address);
494	if (status < B_OK)
495		return status;
496
497	fActiveEndpoints.Insert(endpoint);
498	endpoint->SetActive(true);
499
500	return B_OK;
501}
502
503
504UdpEndpoint *
505UdpDomainSupport::_FindActiveEndpoint(const sockaddr *ourAddress,
506	const sockaddr *peerAddress, uint32 index)
507{
508	ASSERT_LOCKED_MUTEX(&fLock);
509
510	TRACE_DOMAIN("finding Endpoint for %s <- %s",
511		AddressString(fDomain, ourAddress, true).Data(),
512		AddressString(fDomain, peerAddress, true).Data());
513
514	UdpEndpoint* endpoint = fActiveEndpoints.Lookup(
515		std::make_pair(ourAddress, peerAddress));
516
517	// Make sure the bound_to_device constraint is fulfilled
518	while (endpoint != NULL && endpoint->socket->bound_to_device != 0
519		&& index != 0 && endpoint->socket->bound_to_device != index) {
520		endpoint = endpoint->HashTableLink();
521		if (endpoint != NULL
522			&& (!endpoint->LocalAddress().EqualTo(ourAddress, true)
523				|| !endpoint->PeerAddress().EqualTo(peerAddress, true)))
524			return NULL;
525	}
526
527	return endpoint;
528}
529
530
531status_t
532UdpDomainSupport::_DemuxBroadcast(net_buffer* buffer)
533{
534	sockaddr* peerAddr = buffer->source;
535	sockaddr* broadcastAddr = buffer->destination;
536	uint16 incomingPort = AddressModule()->get_port(broadcastAddr);
537
538	sockaddr* mask = NULL;
539	if (buffer->interface_address != NULL)
540		mask = (sockaddr*)buffer->interface_address->mask;
541
542	TRACE_DOMAIN("_DemuxBroadcast(%p): mask %p\n", buffer, mask);
543
544	EndpointTable::Iterator iterator = fActiveEndpoints.GetIterator();
545
546	while (UdpEndpoint* endpoint = iterator.Next()) {
547		TRACE_DOMAIN("  _DemuxBroadcast(): checking endpoint %s...",
548			AddressString(fDomain, *endpoint->LocalAddress(), true).Data());
549
550		if (endpoint->socket->bound_to_device != 0
551			&& buffer->index != endpoint->socket->bound_to_device)
552			continue;
553
554		if (endpoint->LocalAddress().Port() != incomingPort) {
555			// ports don't match, so we do not dispatch to this endpoint...
556			continue;
557		}
558
559		if (!endpoint->PeerAddress().IsEmpty(true)) {
560			// endpoint is connected to a specific destination, we check if
561			// this datagram is from there:
562			if (!endpoint->PeerAddress().EqualTo(peerAddr, true)) {
563				// no, datagram is from another peer, so we do not dispatch to
564				// this endpoint...
565				continue;
566			}
567		}
568
569		if (endpoint->LocalAddress().MatchMasked(broadcastAddr, mask)
570			|| mask == NULL || endpoint->LocalAddress().IsEmpty(false)) {
571			// address matches, dispatch to this endpoint:
572			endpoint->StoreData(buffer);
573		}
574	}
575
576	return B_OK;
577}
578
579
580status_t
581UdpDomainSupport::_DemuxUnicast(net_buffer* buffer)
582{
583	TRACE_DOMAIN("_DemuxUnicast(%p)", buffer);
584
585	const sockaddr* localAddress = buffer->destination;
586	const sockaddr* peerAddress = buffer->source;
587
588	// look for full (most special) match:
589	UdpEndpoint* endpoint = _FindActiveEndpoint(localAddress, peerAddress,
590		buffer->index);
591	if (endpoint == NULL) {
592		// look for endpoint matching local address & port:
593		endpoint = _FindActiveEndpoint(localAddress, NULL, buffer->index);
594		if (endpoint == NULL) {
595			// look for endpoint matching peer address & port and local port:
596			SocketAddressStorage local(AddressModule());
597			local.SetToEmpty();
598			local.SetPort(AddressModule()->get_port(localAddress));
599			endpoint = _FindActiveEndpoint(*local, peerAddress, buffer->index);
600			if (endpoint == NULL) {
601				// last chance: look for endpoint matching local port only:
602				endpoint = _FindActiveEndpoint(*local, NULL, buffer->index);
603			}
604		}
605	}
606
607	if (endpoint == NULL) {
608		TRACE_DOMAIN("_DemuxUnicast(%p) - no matching endpoint found!", buffer);
609		return B_NAME_NOT_FOUND;
610	}
611
612	endpoint->StoreData(buffer);
613	return B_OK;
614}
615
616
617uint16
618UdpDomainSupport::_GetNextEphemeral()
619{
620	uint16 stop, curr;
621	if (fLastUsedEphemeral < kLast) {
622		stop = fLastUsedEphemeral;
623		curr = fLastUsedEphemeral + 1;
624	} else {
625		stop = kLast;
626		curr = kFirst;
627	}
628
629	TRACE_DOMAIN("_GetNextEphemeral(), last %hu, curr %hu, stop %hu",
630		fLastUsedEphemeral, curr, stop);
631
632	// TODO: a free list could be used to avoid the impact of these two
633	//        nested loops most of the time... let's see how bad this really is
634	for (; curr != stop; curr = (curr < kLast) ? (curr + 1) : kFirst) {
635		TRACE_DOMAIN("  _GetNextEphemeral(): trying port %hu...", curr);
636
637		if (_EndpointWithPort(htons(curr)) == NULL) {
638			TRACE_DOMAIN("  _GetNextEphemeral(): ...using port %hu", curr);
639			fLastUsedEphemeral = curr;
640			return curr;
641		}
642	}
643
644	return 0;
645}
646
647
648UdpEndpoint *
649UdpDomainSupport::_EndpointWithPort(uint16 port) const
650{
651	EndpointTable::Iterator it = fActiveEndpoints.GetIterator();
652
653	while (it.HasNext()) {
654		UdpEndpoint *endpoint = it.Next();
655		if (endpoint->LocalAddress().Port() == port)
656			return endpoint;
657	}
658
659	return NULL;
660}
661
662
663// #pragma mark -
664
665
666UdpEndpointManager::UdpEndpointManager()
667{
668	mutex_init(&fLock, "UDP endpoints");
669	fStatus = B_OK;
670}
671
672
673UdpEndpointManager::~UdpEndpointManager()
674{
675	mutex_destroy(&fLock);
676}
677
678
679status_t
680UdpEndpointManager::InitCheck() const
681{
682	return fStatus;
683}
684
685
686int
687UdpEndpointManager::DumpEndpoints(int argc, char *argv[])
688{
689	UdpDomainList::Iterator it = sUdpEndpointManager->fDomains.GetIterator();
690
691	kprintf("===== UDP domain manager %p =====\n", sUdpEndpointManager);
692
693	while (it.HasNext())
694		it.Next()->DumpEndpoints();
695
696	return 0;
697}
698
699
700// #pragma mark - inbound
701
702
// AutoDeleter policy that releases a reference to a UdpDomainSupport instead
// of deleting it: FreeEndpoint() drops the reference and only deletes the
// object once the last reference is gone.
struct DomainSupportDelete
{
	inline void operator()(UdpDomainSupport* object)
	{
		sUdpEndpointManager->FreeEndpoint(object);
	}
};


// Scope guard for a domain support reference obtained via
// UdpEndpointManager::_GetDomainSupport().
struct DomainSupportDeleter
	: BPrivate::AutoDeleter<UdpDomainSupport, DomainSupportDelete>
{
	DomainSupportDeleter(UdpDomainSupport* object)
		: BPrivate::AutoDeleter<UdpDomainSupport, DomainSupportDelete>(object)
	{}
};
719
720
721status_t
722UdpEndpointManager::ReceiveData(net_buffer *buffer)
723{
724	TRACE_EPM("ReceiveData(%p [%" B_PRIu32 " bytes])", buffer, buffer->size);
725
726	UdpDomainSupport* domainSupport = _GetDomainSupport(buffer);
727	if (domainSupport == NULL) {
728		// we don't instantiate domain supports in the receiving path, as
729		// we are only interested in delivering data to existing sockets.
730		return B_ERROR;
731	}
732	DomainSupportDeleter deleter(domainSupport);
733
734	status_t status = Deframe(buffer);
735	if (status != B_OK) {
736		return status;
737	}
738
739	status = domainSupport->DemuxIncomingBuffer(buffer);
740	if (status != B_OK) {
741		TRACE_EPM("  ReceiveData(): no endpoint.");
742		// Send port unreachable error
743		domainSupport->Domain()->module->error_reply(NULL, buffer,
744			B_NET_ERROR_UNREACH_PORT, NULL);
745		return B_ERROR;
746	}
747
748	gBufferModule->free(buffer);
749	return B_OK;
750}
751
752
753status_t
754UdpEndpointManager::ReceiveError(status_t error, net_buffer* buffer)
755{
756	TRACE_EPM("ReceiveError(code %" B_PRId32 " %p [%" B_PRIu32 " bytes])",
757		error, buffer, buffer->size);
758
759	// We only really need the port information
760	if (buffer->size < 4)
761		return B_BAD_VALUE;
762
763	UdpDomainSupport* domainSupport = _GetDomainSupport(buffer);
764	if (domainSupport == NULL) {
765		// we don't instantiate domain supports in the receiving path, as
766		// we are only interested in delivering data to existing sockets.
767		return B_ERROR;
768	}
769	DomainSupportDeleter deleter(domainSupport);
770
771	// Deframe the buffer manually, as we usually only get 8 bytes from the
772	// original packet
773	udp_header header;
774	if (gBufferModule->read(buffer, 0, &header,
775			std::min((size_t)buffer->size, sizeof(udp_header))) != B_OK) {
776		return B_BAD_VALUE;
777	}
778
779	net_domain* domain = buffer->interface_address->domain;
780	net_address_module_info* addressModule = domain->address_module;
781
782	SocketAddress source(addressModule, buffer->source);
783	SocketAddress destination(addressModule, buffer->destination);
784
785	source.SetPort(header.source_port);
786	destination.SetPort(header.destination_port);
787
788	error = domainSupport->DeliverError(error, buffer);
789	return error;
790}
791
792
793status_t
794UdpEndpointManager::Deframe(net_buffer* buffer)
795{
796	TRACE_EPM("Deframe(%p [%" B_PRIu32 " bytes])", buffer, buffer->size);
797
798	NetBufferHeaderReader<udp_header> bufferHeader(buffer);
799	if (bufferHeader.Status() != B_OK)
800		return bufferHeader.Status();
801
802	udp_header& header = bufferHeader.Data();
803
804	net_domain* domain = _GetDomain(buffer);
805	if (domain == NULL) {
806		TRACE_EPM("  Deframe(): UDP packed dropped as there was no domain "
807			"specified (interface address %p).", buffer->interface_address);
808		return B_BAD_VALUE;
809	}
810	net_address_module_info* addressModule = domain->address_module;
811
812	SocketAddress source(addressModule, buffer->source);
813	SocketAddress destination(addressModule, buffer->destination);
814
815	source.SetPort(header.source_port);
816	destination.SetPort(header.destination_port);
817
818	TRACE_EPM("  Deframe(): data from %s to %s", source.AsString(true).Data(),
819		destination.AsString(true).Data());
820
821	uint16 udpLength = ntohs(header.udp_length);
822	if (udpLength > buffer->size) {
823		TRACE_EPM("  Deframe(): buffer is too short, expected %hu.",
824			udpLength);
825		return B_MISMATCHED_VALUES;
826	}
827
828	if (buffer->size > udpLength)
829		gBufferModule->trim(buffer, udpLength);
830
831	if (header.udp_checksum != 0) {
832		// check UDP-checksum (simulating a so-called "pseudo-header"):
833		uint16 sum = Checksum::PseudoHeader(addressModule, gBufferModule,
834			buffer, IPPROTO_UDP);
835		if (sum != 0) {
836			TRACE_EPM("  Deframe(): bad checksum 0x%hx.", sum);
837			return B_BAD_VALUE;
838		}
839	}
840
841	bufferHeader.Remove();
842		// remove UDP-header from buffer before passing it on
843
844	return B_OK;
845}
846
847
848UdpDomainSupport *
849UdpEndpointManager::OpenEndpoint(UdpEndpoint *endpoint)
850{
851	MutexLocker _(fLock);
852
853	UdpDomainSupport* domain = _GetDomainSupport(endpoint->Domain(), true);
854	return domain;
855}
856
857
858status_t
859UdpEndpointManager::FreeEndpoint(UdpDomainSupport *domain)
860{
861	MutexLocker _(fLock);
862
863	if (domain->Put()) {
864		fDomains.Remove(domain);
865		delete domain;
866	}
867
868	return B_OK;
869}
870
871
872// #pragma mark -
873
874
875inline net_domain*
876UdpEndpointManager::_GetDomain(net_buffer* buffer)
877{
878	if (buffer->interface_address != NULL)
879		return buffer->interface_address->domain;
880
881	return gStackModule->get_domain(buffer->destination->sa_family);
882}
883
884
885UdpDomainSupport*
886UdpEndpointManager::_GetDomainSupport(net_domain* domain, bool create)
887{
888	ASSERT_LOCKED_MUTEX(&fLock);
889
890	if (domain == NULL)
891		return NULL;
892
893	// TODO convert this into a Hashtable or install per-domain
894	//      receiver handlers that forward the requests to the
895	//      appropriate DemuxIncomingBuffer(). For instance, while
896	//      being constructed UdpDomainSupport could call
897	//      register_domain_receiving_protocol() with the right
898	//      family.
899	UdpDomainList::Iterator iterator = fDomains.GetIterator();
900	while (UdpDomainSupport* domainSupport = iterator.Next()) {
901		if (domainSupport->Domain() == domain) {
902			domainSupport->Ref();
903			return domainSupport;
904		}
905	}
906
907	if (!create)
908		return NULL;
909
910	UdpDomainSupport* domainSupport
911		= new (std::nothrow) UdpDomainSupport(domain);
912	if (domainSupport == NULL || domainSupport->Init() < B_OK) {
913		delete domainSupport;
914		return NULL;
915	}
916
917	fDomains.Add(domainSupport);
918	domainSupport->Ref();
919	return domainSupport;
920}
921
922
923/*!	Retrieves the UdpDomainSupport object responsible for this buffer, if the
924	domain can be determined. This is only successful if the domain support is
925	already existing, ie. there must already be an endpoint for the domain.
926*/
927UdpDomainSupport*
928UdpEndpointManager::_GetDomainSupport(net_buffer* buffer)
929{
930	MutexLocker _(fLock);
931
932	return _GetDomainSupport(_GetDomain(buffer), false);
933}
934
935
936// #pragma mark -
937
938
939UdpEndpoint::UdpEndpoint(net_socket *socket)
940	:
941	DatagramSocket<>("udp endpoint", socket),
942	fActive(false)
943{
944}
945
946
947// #pragma mark - activation
948
949
950status_t
951UdpEndpoint::Bind(const sockaddr *address)
952{
953	TRACE_EP("Bind(%s)", AddressString(Domain(), address, true).Data());
954	return fManager->BindEndpoint(this, address);
955}
956
957
958status_t
959UdpEndpoint::Unbind(sockaddr *address)
960{
961	TRACE_EP("Unbind()");
962	return fManager->UnbindEndpoint(this);
963}
964
965
966status_t
967UdpEndpoint::Connect(const sockaddr *address)
968{
969	TRACE_EP("Connect(%s)", AddressString(Domain(), address, true).Data());
970	return fManager->ConnectEndpoint(this, address);
971}
972
973
974status_t
975UdpEndpoint::Open()
976{
977	TRACE_EP("Open()");
978
979	AutoLocker _(fLock);
980
981	status_t status = ProtocolSocket::Open();
982	if (status < B_OK)
983		return status;
984
985	fManager = sUdpEndpointManager->OpenEndpoint(this);
986	if (fManager == NULL)
987		return EAFNOSUPPORT;
988
989	return B_OK;
990}
991
992
993status_t
994UdpEndpoint::Close()
995{
996	TRACE_EP("Close()");
997	fSocket->error = EBADF;
998	WakeAll();
999	return B_OK;
1000}
1001
1002
1003status_t
1004UdpEndpoint::Free()
1005{
1006	TRACE_EP("Free()");
1007	fManager->UnbindEndpoint(this);
1008	return sUdpEndpointManager->FreeEndpoint(fManager);
1009}
1010
1011
1012// #pragma mark - outbound
1013
1014
1015status_t
1016UdpEndpoint::SendRoutedData(net_buffer *buffer, net_route *route)
1017{
1018	TRACE_EP("SendRoutedData(%p [%" B_PRIu32 " bytes], %p)", buffer,
1019		buffer->size, route);
1020
1021	if (buffer->size > (0xffff - sizeof(udp_header)))
1022		return EMSGSIZE;
1023
1024	buffer->protocol = IPPROTO_UDP;
1025
1026	// add and fill UDP-specific header:
1027	NetBufferPrepend<udp_header> header(buffer);
1028	if (header.Status() < B_OK)
1029		return header.Status();
1030
1031	header->source_port = AddressModule()->get_port(buffer->source);
1032	header->destination_port = AddressModule()->get_port(buffer->destination);
1033	header->udp_length = htons(buffer->size);
1034		// the udp-header is already included in the buffer-size
1035	header->udp_checksum = 0;
1036
1037	header.Sync();
1038
1039	uint16 calculatedChecksum = Checksum::PseudoHeader(AddressModule(),
1040		gBufferModule, buffer, IPPROTO_UDP);
1041	if (calculatedChecksum == 0)
1042		calculatedChecksum = 0xffff;
1043
1044	*UDPChecksumField(buffer) = calculatedChecksum;
1045
1046	return next->module->send_routed_data(next, route, buffer);
1047}
1048
1049
1050status_t
1051UdpEndpoint::SendData(net_buffer *buffer)
1052{
1053	TRACE_EP("SendData(%p [%" B_PRIu32 " bytes])", buffer, buffer->size);
1054
1055	return gDatalinkModule->send_data(this, NULL, buffer);
1056}
1057
1058
1059// #pragma mark - inbound
1060
1061
1062ssize_t
1063UdpEndpoint::BytesAvailable()
1064{
1065	size_t bytes = AvailableData();
1066	TRACE_EP("BytesAvailable(): %lu", bytes);
1067	return bytes;
1068}
1069
1070
1071status_t
1072UdpEndpoint::FetchData(size_t numBytes, uint32 flags, net_buffer **_buffer)
1073{
1074	TRACE_EP("FetchData(%" B_PRIuSIZE ", 0x%" B_PRIx32 ")", numBytes, flags);
1075
1076	status_t status = Dequeue(flags, _buffer);
1077	TRACE_EP("  FetchData(): returned from fifo status: %s", strerror(status));
1078	if (status != B_OK)
1079		return status;
1080
1081	TRACE_EP("  FetchData(): returns buffer with %" B_PRIu32 " bytes",
1082		(*_buffer)->size);
1083	return B_OK;
1084}
1085
1086
1087status_t
1088UdpEndpoint::StoreData(net_buffer *buffer)
1089{
1090	TRACE_EP("StoreData(%p [%" B_PRIu32 " bytes])", buffer, buffer->size);
1091
1092	return EnqueueClone(buffer);
1093}
1094
1095
1096status_t
1097UdpEndpoint::DeliverData(net_buffer *_buffer)
1098{
1099	TRACE_EP("DeliverData(%p [%" B_PRIu32 " bytes])", _buffer, _buffer->size);
1100
1101	net_buffer *buffer = gBufferModule->clone(_buffer, false);
1102	if (buffer == NULL)
1103		return B_NO_MEMORY;
1104
1105	status_t status = sUdpEndpointManager->Deframe(buffer);
1106	if (status < B_OK) {
1107		gBufferModule->free(buffer);
1108		return status;
1109	}
1110
1111	return Enqueue(buffer);
1112}
1113
1114
1115void
1116UdpEndpoint::Dump() const
1117{
1118	char local[64];
1119	LocalAddress().AsString(local, sizeof(local), true);
1120	char peer[64];
1121	PeerAddress().AsString(peer, sizeof(peer), true);
1122
1123	kprintf("%p %20s %20s %8lu\n", this, local, peer, fCurrentBytes);
1124}
1125
1126
1127// #pragma mark - protocol interface
1128
1129
1130net_protocol *
1131udp_init_protocol(net_socket *socket)
1132{
1133	socket->protocol = IPPROTO_UDP;
1134
1135	UdpEndpoint *endpoint = new (std::nothrow) UdpEndpoint(socket);
1136	if (endpoint == NULL || endpoint->InitCheck() < B_OK) {
1137		delete endpoint;
1138		return NULL;
1139	}
1140
1141	return endpoint;
1142}
1143
1144
1145status_t
1146udp_uninit_protocol(net_protocol *protocol)
1147{
1148	delete (UdpEndpoint *)protocol;
1149	return B_OK;
1150}
1151
1152
1153status_t
1154udp_open(net_protocol *protocol)
1155{
1156	return ((UdpEndpoint *)protocol)->Open();
1157}
1158
1159
1160status_t
1161udp_close(net_protocol *protocol)
1162{
1163	return ((UdpEndpoint *)protocol)->Close();
1164}
1165
1166
1167status_t
1168udp_free(net_protocol *protocol)
1169{
1170	return ((UdpEndpoint *)protocol)->Free();
1171}
1172
1173
1174status_t
1175udp_connect(net_protocol *protocol, const struct sockaddr *address)
1176{
1177	return ((UdpEndpoint *)protocol)->Connect(address);
1178}
1179
1180
1181status_t
1182udp_accept(net_protocol *protocol, struct net_socket **_acceptedSocket)
1183{
1184	return B_NOT_SUPPORTED;
1185}
1186
1187
1188status_t
1189udp_control(net_protocol *protocol, int level, int option, void *value,
1190	size_t *_length)
1191{
1192	return protocol->next->module->control(protocol->next, level, option,
1193		value, _length);
1194}
1195
1196
1197status_t
1198udp_getsockopt(net_protocol *protocol, int level, int option, void *value,
1199	int *length)
1200{
1201	return protocol->next->module->getsockopt(protocol->next, level, option,
1202		value, length);
1203}
1204
1205
1206status_t
1207udp_setsockopt(net_protocol *protocol, int level, int option,
1208	const void *value, int length)
1209{
1210	return protocol->next->module->setsockopt(protocol->next, level, option,
1211		value, length);
1212}
1213
1214
1215status_t
1216udp_bind(net_protocol *protocol, const struct sockaddr *address)
1217{
1218	return ((UdpEndpoint *)protocol)->Bind(address);
1219}
1220
1221
1222status_t
1223udp_unbind(net_protocol *protocol, struct sockaddr *address)
1224{
1225	return ((UdpEndpoint *)protocol)->Unbind(address);
1226}
1227
1228
1229status_t
1230udp_listen(net_protocol *protocol, int count)
1231{
1232	return B_NOT_SUPPORTED;
1233}
1234
1235
1236status_t
1237udp_shutdown(net_protocol *protocol, int direction)
1238{
1239	return B_NOT_SUPPORTED;
1240}
1241
1242
1243status_t
1244udp_send_routed_data(net_protocol *protocol, struct net_route *route,
1245	net_buffer *buffer)
1246{
1247	return ((UdpEndpoint *)protocol)->SendRoutedData(buffer, route);
1248}
1249
1250
1251status_t
1252udp_send_data(net_protocol *protocol, net_buffer *buffer)
1253{
1254	return ((UdpEndpoint *)protocol)->SendData(buffer);
1255}
1256
1257
1258ssize_t
1259udp_send_avail(net_protocol *protocol)
1260{
1261	return protocol->socket->send.buffer_size;
1262}
1263
1264
1265status_t
1266udp_read_data(net_protocol *protocol, size_t numBytes, uint32 flags,
1267	net_buffer **_buffer)
1268{
1269	return ((UdpEndpoint *)protocol)->FetchData(numBytes, flags, _buffer);
1270}
1271
1272
1273ssize_t
1274udp_read_avail(net_protocol *protocol)
1275{
1276	return ((UdpEndpoint *)protocol)->BytesAvailable();
1277}
1278
1279
1280struct net_domain *
1281udp_get_domain(net_protocol *protocol)
1282{
1283	return protocol->next->module->get_domain(protocol->next);
1284}
1285
1286
1287size_t
1288udp_get_mtu(net_protocol *protocol, const struct sockaddr *address)
1289{
1290	return protocol->next->module->get_mtu(protocol->next, address);
1291}
1292
1293
1294status_t
1295udp_receive_data(net_buffer *buffer)
1296{
1297	return sUdpEndpointManager->ReceiveData(buffer);
1298}
1299
1300
1301status_t
1302udp_deliver_data(net_protocol *protocol, net_buffer *buffer)
1303{
1304	return ((UdpEndpoint *)protocol)->DeliverData(buffer);
1305}
1306
1307
1308status_t
1309udp_error_received(net_error error, net_buffer* buffer)
1310{
1311	status_t notifyError = B_OK;
1312
1313	switch (error) {
1314		case B_NET_ERROR_UNREACH_NET:
1315			notifyError = ENETUNREACH;
1316			break;
1317		case B_NET_ERROR_UNREACH_HOST:
1318		case B_NET_ERROR_TRANSIT_TIME_EXCEEDED:
1319			notifyError = EHOSTUNREACH;
1320			break;
1321		case B_NET_ERROR_UNREACH_PROTOCOL:
1322		case B_NET_ERROR_UNREACH_PORT:
1323			notifyError = ECONNREFUSED;
1324			break;
1325		case B_NET_ERROR_MESSAGE_SIZE:
1326			notifyError = EMSGSIZE;
1327			break;
1328		case B_NET_ERROR_PARAMETER_PROBLEM:
1329			notifyError = ENOPROTOOPT;
1330			break;
1331
1332		case B_NET_ERROR_QUENCH:
1333		default:
1334			// ignore them
1335			gBufferModule->free(buffer);
1336			return B_OK;
1337	}
1338
1339	ASSERT(notifyError != B_OK);
1340
1341	return sUdpEndpointManager->ReceiveError(notifyError, buffer);
1342}
1343
1344
1345status_t
1346udp_error_reply(net_protocol *protocol, net_buffer *cause, net_error error,
1347	net_error_data *errorData)
1348{
1349	return B_ERROR;
1350}
1351
1352
1353ssize_t
1354udp_process_ancillary_data_no_container(net_protocol *protocol,
1355	net_buffer* buffer, void *data, size_t dataSize)
1356{
1357	return protocol->next->module->process_ancillary_data_no_container(
1358		protocol, buffer, data, dataSize);
1359}
1360
1361
1362//	#pragma mark - module interface
1363
1364
1365static status_t
1366init_udp()
1367{
1368	status_t status;
1369	TRACE_EPM("init_udp()");
1370
1371	sUdpEndpointManager = new (std::nothrow) UdpEndpointManager;
1372	if (sUdpEndpointManager == NULL)
1373		return B_NO_MEMORY;
1374
1375	status = sUdpEndpointManager->InitCheck();
1376	if (status != B_OK)
1377		goto err1;
1378
1379	status = gStackModule->register_domain_protocols(AF_INET, SOCK_DGRAM,
1380		IPPROTO_IP,
1381		"network/protocols/udp/v1",
1382		"network/protocols/ipv4/v1",
1383		NULL);
1384	if (status < B_OK)
1385		goto err1;
1386	status = gStackModule->register_domain_protocols(AF_INET6, SOCK_DGRAM,
1387		IPPROTO_IP,
1388		"network/protocols/udp/v1",
1389		"network/protocols/ipv6/v1",
1390		NULL);
1391	if (status < B_OK)
1392		goto err1;
1393
1394	status = gStackModule->register_domain_protocols(AF_INET, SOCK_DGRAM,
1395		IPPROTO_UDP,
1396		"network/protocols/udp/v1",
1397		"network/protocols/ipv4/v1",
1398		NULL);
1399	if (status < B_OK)
1400		goto err1;
1401	status = gStackModule->register_domain_protocols(AF_INET6, SOCK_DGRAM,
1402		IPPROTO_UDP,
1403		"network/protocols/udp/v1",
1404		"network/protocols/ipv6/v1",
1405		NULL);
1406	if (status < B_OK)
1407		goto err1;
1408
1409	status = gStackModule->register_domain_receiving_protocol(AF_INET,
1410		IPPROTO_UDP, "network/protocols/udp/v1");
1411	if (status < B_OK)
1412		goto err1;
1413	status = gStackModule->register_domain_receiving_protocol(AF_INET6,
1414		IPPROTO_UDP, "network/protocols/udp/v1");
1415	if (status < B_OK)
1416		goto err1;
1417
1418	add_debugger_command("udp_endpoints", UdpEndpointManager::DumpEndpoints,
1419		"lists all open UDP endpoints");
1420
1421	return B_OK;
1422
1423err1:
1424	// TODO: shouldn't unregister the protocols here?
1425	delete sUdpEndpointManager;
1426
1427	TRACE_EPM("init_udp() fails with %" B_PRIx32 " (%s)", status,
1428		strerror(status));
1429	return status;
1430}
1431
1432
1433static status_t
1434uninit_udp()
1435{
1436	TRACE_EPM("uninit_udp()");
1437	remove_debugger_command("udp_endpoints",
1438		UdpEndpointManager::DumpEndpoints);
1439	delete sUdpEndpointManager;
1440	return B_OK;
1441}
1442
1443
1444static status_t
1445udp_std_ops(int32 op, ...)
1446{
1447	switch (op) {
1448		case B_MODULE_INIT:
1449			return init_udp();
1450
1451		case B_MODULE_UNINIT:
1452			return uninit_udp();
1453
1454		default:
1455			return B_ERROR;
1456	}
1457}
1458
1459
// Module descriptor for the UDP protocol, exported through modules[] below.
// The initializer is positional, so the order of the entries must match the
// declaration of net_protocol_module_info: first the generic module header
// (module name, 0, std_ops hook), then the protocol flags, then the protocol
// hooks. Hooks UDP does not implement are left NULL.
net_protocol_module_info sUDPModule = {
	{
		"network/protocols/udp/v1",
		0,
		udp_std_ops
	},
	NET_PROTOCOL_ATOMIC_MESSAGES,
		// NOTE(review): presumably marks UDP as message-preserving
		// (datagrams are never split or merged) -- confirm in net_protocol.h

	udp_init_protocol,
	udp_uninit_protocol,
	udp_open,
	udp_close,
	udp_free,
	udp_connect,
	udp_accept,
	udp_control,
	udp_getsockopt,
	udp_setsockopt,
	udp_bind,
	udp_unbind,
	udp_listen,
	udp_shutdown,
	udp_send_data,
	udp_send_routed_data,
	udp_send_avail,
	udp_read_data,
	udp_read_avail,
	udp_get_domain,
	udp_get_mtu,
	udp_receive_data,
	udp_deliver_data,
	udp_error_received,
	udp_error_reply,
	NULL,		// add_ancillary_data()
	NULL,		// process_ancillary_data()
	udp_process_ancillary_data_no_container,
	NULL,		// send_data_no_buffer()
	NULL		// read_data_no_buffer()
};
1499
1500module_dependency module_dependencies[] = {
1501	{NET_STACK_MODULE_NAME, (module_info **)&gStackModule},
1502	{NET_BUFFER_MODULE_NAME, (module_info **)&gBufferModule},
1503	{NET_DATALINK_MODULE_NAME, (module_info **)&gDatalinkModule},
1504	{NET_SOCKET_MODULE_NAME, (module_info **)&gSocketModule},
1505	{}
1506};
1507
1508module_info *modules[] = {
1509	(module_info *)&sUDPModule,
1510	NULL
1511};
1512