1/*
2 * Copyright 2005, Ingo Weinhold <bonefish@cs.tu-berlin.de>.
3 * All rights reserved. Distributed under the terms of the MIT License.
4 */
5
6
7#include <boot/net/UDP.h>
8
9#include <stdio.h>
10
11#include <KernelExport.h>
12
13#include <boot/net/ChainBuffer.h>
14#include <boot/net/NetStack.h>
15
16
// Uncomment the following define to route the TRACE() output of this file to
// dprintf().
//#define TRACE_UDP
#ifdef TRACE_UDP
	// TRACE() is invoked with a parenthesized argument list, i.e.
	// TRACE(("format", args...)), which is pasted directly after dprintf.
#	define TRACE(x) dprintf x
#else
	// Tracing disabled: TRACE() expands to an empty statement.
#	define TRACE(x) ;
#endif
23
24
25using std::nothrow;
26
27
28// #pragma mark - UDPPacket
29
30
31UDPPacket::UDPPacket()
32	:
33	fNext(NULL),
34	fData(NULL),
35	fSize(0)
36{
37}
38
39
40UDPPacket::~UDPPacket()
41{
42	free(fData);
43}
44
45
46status_t
47UDPPacket::SetTo(const void *data, size_t size, ip_addr_t sourceAddress,
48	uint16 sourcePort, ip_addr_t destinationAddress, uint16 destinationPort)
49{
50	if (data == NULL)
51		return B_BAD_VALUE;
52
53	// clone the data
54	fData = malloc(size);
55	if (fData == NULL)
56		return B_NO_MEMORY;
57	memcpy(fData, data, size);
58
59	fSize = size;
60	fSourceAddress = sourceAddress;
61	fDestinationAddress = destinationAddress;
62	fSourcePort = sourcePort;
63	fDestinationPort = destinationPort;
64
65	return B_OK;
66}
67
68
69UDPPacket *
70UDPPacket::Next() const
71{
72	return fNext;
73}
74
75
76void
77UDPPacket::SetNext(UDPPacket *next)
78{
79	fNext = next;
80}
81
82
83const void *
84UDPPacket::Data() const
85{
86	return fData;
87}
88
89
90size_t
91UDPPacket::DataSize() const
92{
93	return fSize;
94}
95
96
97ip_addr_t
98UDPPacket::SourceAddress() const
99{
100	return fSourceAddress;
101}
102
103
104uint16
105UDPPacket::SourcePort() const
106{
107	return fSourcePort;
108}
109
110
111ip_addr_t
112UDPPacket::DestinationAddress() const
113{
114	return fDestinationAddress;
115}
116
117
118uint16
119UDPPacket::DestinationPort() const
120{
121	return fDestinationPort;
122}
123
124
125// #pragma mark - UDPSocket
126
127
128UDPSocket::UDPSocket()
129	:
130	fUDPService(NetStack::Default()->GetUDPService()),
131	fFirstPacket(NULL),
132	fLastPacket(NULL),
133	fAddress(INADDR_ANY),
134	fPort(0)
135{
136}
137
138
139UDPSocket::~UDPSocket()
140{
141	if (fPort != 0 && fUDPService != NULL)
142		fUDPService->UnbindSocket(this);
143}
144
145
146status_t
147UDPSocket::Bind(ip_addr_t address, uint16 port)
148{
149	if (fUDPService == NULL) {
150		printf("UDPSocket::Bind(): no UDP service\n");
151		return B_NO_INIT;
152	}
153
154	if (address == INADDR_BROADCAST || port == 0) {
155		printf("UDPSocket::Bind(): broadcast IP or port 0\n");
156		return B_BAD_VALUE;
157	}
158
159	if (fPort != 0) {
160		printf("UDPSocket::Bind(): already bound\n");
161		return EALREADY;
162			// correct code?
163	}
164
165	status_t error = fUDPService->BindSocket(this, address, port);
166	if (error != B_OK) {
167		printf("UDPSocket::Bind(): service BindSocket() failed\n");
168		return error;
169	}
170
171	fAddress = address;
172	fPort = port;
173
174	return B_OK;
175}
176
177
178void
179UDPSocket::Detach()
180{
181	fUDPService = NULL;
182		// This will lead to subsequent methods returning B_NO_INIT
183}
184
185
186
187status_t
188UDPSocket::Send(ip_addr_t destinationAddress, uint16 destinationPort,
189	ChainBuffer *buffer)
190{
191	if (fUDPService == NULL)
192		return B_NO_INIT;
193
194	return fUDPService->Send(fPort, destinationAddress, destinationPort,
195		buffer);
196}
197
198
199status_t
200UDPSocket::Send(ip_addr_t destinationAddress, uint16 destinationPort,
201	const void *data, size_t size)
202{
203	if (data == NULL)
204		return B_BAD_VALUE;
205
206	ChainBuffer buffer((void*)data, size);
207	return Send(destinationAddress, destinationPort, &buffer);
208}
209
210
211status_t
212UDPSocket::Receive(UDPPacket **_packet, bigtime_t timeout)
213{
214	if (fUDPService == NULL)
215		return B_NO_INIT;
216
217	if (_packet == NULL)
218		return B_BAD_VALUE;
219
220	bigtime_t startTime = system_time();
221	for (;;) {
222		fUDPService->ProcessIncomingPackets();
223		*_packet = PopPacket();
224		if (*_packet != NULL)
225			return B_OK;
226
227		if (system_time() - startTime > timeout)
228			return (timeout == 0 ? B_WOULD_BLOCK : B_TIMED_OUT);
229	}
230}
231
232
233void
234UDPSocket::PushPacket(UDPPacket *packet)
235{
236	if (fLastPacket != NULL)
237		fLastPacket->SetNext(packet);
238	else
239		fFirstPacket = packet;
240
241	fLastPacket = packet;
242	packet->SetNext(NULL);
243}
244
245
246UDPPacket *
247UDPSocket::PopPacket()
248{
249	if (fFirstPacket == NULL)
250		return NULL;
251
252	UDPPacket *packet = fFirstPacket;
253	fFirstPacket = packet->Next();
254
255	if (fFirstPacket == NULL)
256		fLastPacket = NULL;
257
258	packet->SetNext(NULL);
259	return packet;
260}
261
262
263// #pragma mark - UDPService
264
265
266UDPService::UDPService(IPService *ipService)
267	:
268	IPSubService(kUDPServiceName),
269	fIPService(ipService)
270{
271}
272
273
274UDPService::~UDPService()
275{
276	int count = fSockets.Count();
277	for (int i = 0; i < count; i++) {
278		UDPSocket *socket = fSockets.ElementAt(i);
279		socket->Detach();
280	}
281
282	if (fIPService != NULL)
283		fIPService->UnregisterIPSubService(this);
284}
285
286
287status_t
288UDPService::Init()
289{
290	if (fIPService == NULL)
291		return B_BAD_VALUE;
292	if (!fIPService->RegisterIPSubService(this))
293		return B_NO_MEMORY;
294	return B_OK;
295}
296
297
298uint8
299UDPService::IPProtocol() const
300{
301	return IPPROTO_UDP;
302}
303
304
305void
306UDPService::HandleIPPacket(IPService *ipService, ip_addr_t sourceIP,
307	ip_addr_t destinationIP, const void *data, size_t size)
308{
309	TRACE(("UDPService::HandleIPPacket(): source: %08lx, destination: %08lx, "
310		"%lu - %lu bytes\n", sourceIP, destinationIP, size,
311		sizeof(udp_header)));
312
313	if (data == NULL || size < sizeof(udp_header))
314		return;
315
316	const udp_header *header = (const udp_header*)data;
317	uint16 source = ntohs(header->source);
318	uint16 destination = ntohs(header->destination);
319	uint16 length = ntohs(header->length);
320
321	// check the header
322	if (length < sizeof(udp_header) || length > size
323		|| (header->checksum != 0	// 0 => checksum disabled
324			&& _ChecksumData(data, length, sourceIP, destinationIP) != 0)) {
325		TRACE(("UDPService::HandleIPPacket(): dropping packet -- invalid size "
326			"or checksum\n"));
327		return;
328	}
329
330	// find the target socket
331	UDPSocket *socket = _FindSocket(destinationIP, destination);
332	if (socket == NULL)
333		return;
334
335	// create a UDPPacket and queue it in the socket
336	UDPPacket *packet = new(nothrow) UDPPacket;
337	if (packet == NULL)
338		return;
339	status_t error = packet->SetTo((uint8*)data + sizeof(udp_header),
340		length - sizeof(udp_header), sourceIP, source, destinationIP,
341		destination);
342	if (error == B_OK)
343		socket->PushPacket(packet);
344	else
345		delete packet;
346}
347
348
349status_t
350UDPService::Send(uint16 sourcePort, ip_addr_t destinationAddress,
351	uint16 destinationPort, ChainBuffer *buffer)
352{
353	TRACE(("UDPService::Send(source port: %hu, to: %08lx:%hu, %lu bytes)\n",
354		sourcePort, destinationAddress, destinationPort,
355		(buffer != NULL ? buffer->TotalSize() : 0)));
356
357	if (fIPService == NULL)
358		return B_NO_INIT;
359
360	if (buffer == NULL)
361		return B_BAD_VALUE;
362
363	// prepend the UDP header
364	udp_header header;
365	ChainBuffer headerBuffer(&header, sizeof(header), buffer);
366	header.source = htons(sourcePort);
367	header.destination = htons(destinationPort);
368	header.length = htons(headerBuffer.TotalSize());
369
370	// compute the checksum
371	header.checksum = 0;
372	header.checksum = htons(_ChecksumBuffer(&headerBuffer,
373		fIPService->IPAddress(), destinationAddress,
374		headerBuffer.TotalSize()));
375	// 0 means checksum disabled; 0xffff is equivalent in this case
376	if (header.checksum == 0)
377		header.checksum = 0xffff;
378
379	return fIPService->Send(destinationAddress, IPPROTO_UDP, &headerBuffer);
380}
381
382
383void
384UDPService::ProcessIncomingPackets()
385{
386	if (fIPService != NULL)
387		fIPService->ProcessIncomingPackets();
388}
389
390
391status_t
392UDPService::BindSocket(UDPSocket *socket, ip_addr_t address, uint16 port)
393{
394	if (socket == NULL)
395		return B_BAD_VALUE;
396
397	if (_FindSocket(address, port) != NULL) {
398		printf("UDPService::BindSocket(): address in use\n");
399		return EADDRINUSE;
400	}
401
402	return fSockets.Add(socket);
403}
404
405
406void
407UDPService::UnbindSocket(UDPSocket *socket)
408{
409	fSockets.Remove(socket);
410}
411
412
413uint16
414UDPService::_ChecksumBuffer(ChainBuffer *buffer, ip_addr_t source,
415	ip_addr_t destination, uint16 length)
416{
417	// The checksum is calculated over a pseudo-header plus the UDP packet.
418	// So we temporarily prepend the pseudo-header.
419	struct pseudo_header {
420		ip_addr_t	source;
421		ip_addr_t	destination;
422		uint8		pad;
423		uint8		protocol;
424		uint16		length;
425	} __attribute__ ((__packed__));
426	pseudo_header header = {
427		htonl(source),
428		htonl(destination),
429		0,
430		IPPROTO_UDP,
431		htons(length)
432	};
433
434	ChainBuffer headerBuffer(&header, sizeof(header), buffer);
435	uint16 checksum = ip_checksum(&headerBuffer);
436	headerBuffer.DetachNext();
437	return checksum;
438}
439
440
441uint16
442UDPService::_ChecksumData(const void *data, uint16 length, ip_addr_t source,
443	ip_addr_t destination)
444{
445	ChainBuffer buffer((void*)data, length);
446	return _ChecksumBuffer(&buffer, source, destination, length);
447}
448
449
450UDPSocket *
451UDPService::_FindSocket(ip_addr_t address, uint16 port)
452{
453	int count = fSockets.Count();
454	for (int i = 0; i < count; i++) {
455		UDPSocket *socket = fSockets.ElementAt(i);
456		if ((address == INADDR_ANY || socket->Address() == INADDR_ANY
457				|| socket->Address() == address)
458			&& port == socket->Port()) {
459			return socket;
460		}
461	}
462
463	return NULL;
464}
465