1/*
2 * Copyright 2005, Ingo Weinhold <bonefish@cs.tu-berlin.de>.
3 * All rights reserved. Distributed under the terms of the MIT License.
4 */
5
6
#include <boot/net/UDP.h>

#include <errno.h>
#include <stdio.h>

#include <KernelExport.h>

#include <boot/net/ChainBuffer.h>
#include <boot/net/NetStack.h>
15
16
// Debug tracing. Uncomment TRACE_UDP to route TRACE() output to dprintf().
// TRACE() takes a parenthesized argument list, e.g. TRACE(("x: %d\n", x)),
// so that the whole list can be forwarded as a single macro argument.
//#define TRACE_UDP
#ifdef TRACE_UDP
#	define TRACE(x) dprintf x
#else
	// expands to an empty statement when tracing is disabled
#	define TRACE(x) ;
#endif
23
24
25// #pragma mark - UDPPacket
26
27
28UDPPacket::UDPPacket()
29	:
30	fNext(NULL),
31	fData(NULL),
32	fSize(0)
33{
34}
35
36
37UDPPacket::~UDPPacket()
38{
39	free(fData);
40}
41
42
43status_t
44UDPPacket::SetTo(const void *data, size_t size, ip_addr_t sourceAddress,
45	uint16 sourcePort, ip_addr_t destinationAddress, uint16 destinationPort)
46{
47	if (data == NULL)
48		return B_BAD_VALUE;
49
50	// clone the data
51	fData = malloc(size);
52	if (fData == NULL)
53		return B_NO_MEMORY;
54	memcpy(fData, data, size);
55
56	fSize = size;
57	fSourceAddress = sourceAddress;
58	fDestinationAddress = destinationAddress;
59	fSourcePort = sourcePort;
60	fDestinationPort = destinationPort;
61
62	return B_OK;
63}
64
65
66UDPPacket *
67UDPPacket::Next() const
68{
69	return fNext;
70}
71
72
73void
74UDPPacket::SetNext(UDPPacket *next)
75{
76	fNext = next;
77}
78
79
80const void *
81UDPPacket::Data() const
82{
83	return fData;
84}
85
86
87size_t
88UDPPacket::DataSize() const
89{
90	return fSize;
91}
92
93
94ip_addr_t
95UDPPacket::SourceAddress() const
96{
97	return fSourceAddress;
98}
99
100
101uint16
102UDPPacket::SourcePort() const
103{
104	return fSourcePort;
105}
106
107
108ip_addr_t
109UDPPacket::DestinationAddress() const
110{
111	return fDestinationAddress;
112}
113
114
115uint16
116UDPPacket::DestinationPort() const
117{
118	return fDestinationPort;
119}
120
121
122// #pragma mark - UDPSocket
123
124
125UDPSocket::UDPSocket()
126	:
127	fUDPService(NetStack::Default()->GetUDPService()),
128	fFirstPacket(NULL),
129	fLastPacket(NULL),
130	fAddress(INADDR_ANY),
131	fPort(0)
132{
133}
134
135
136UDPSocket::~UDPSocket()
137{
138	if (fPort != 0 && fUDPService != NULL)
139		fUDPService->UnbindSocket(this);
140}
141
142
143status_t
144UDPSocket::Bind(ip_addr_t address, uint16 port)
145{
146	if (fUDPService == NULL) {
147		printf("UDPSocket::Bind(): no UDP service\n");
148		return B_NO_INIT;
149	}
150
151	if (address == INADDR_BROADCAST || port == 0) {
152		printf("UDPSocket::Bind(): broadcast IP or port 0\n");
153		return B_BAD_VALUE;
154	}
155
156	if (fPort != 0) {
157		printf("UDPSocket::Bind(): already bound\n");
158		return EALREADY;
159			// correct code?
160	}
161
162	status_t error = fUDPService->BindSocket(this, address, port);
163	if (error != B_OK) {
164		printf("UDPSocket::Bind(): service BindSocket() failed\n");
165		return error;
166	}
167
168	fAddress = address;
169	fPort = port;
170
171	return B_OK;
172}
173
174
175status_t
176UDPSocket::Send(ip_addr_t destinationAddress, uint16 destinationPort,
177	ChainBuffer *buffer)
178{
179	if (fUDPService == NULL)
180		return B_NO_INIT;
181
182	return fUDPService->Send(fPort, destinationAddress, destinationPort,
183		buffer);
184}
185
186
187status_t
188UDPSocket::Send(ip_addr_t destinationAddress, uint16 destinationPort,
189	const void *data, size_t size)
190{
191	if (data == NULL)
192		return B_BAD_VALUE;
193
194	ChainBuffer buffer((void*)data, size);
195	return Send(destinationAddress, destinationPort, &buffer);
196}
197
198
199status_t
200UDPSocket::Receive(UDPPacket **_packet, bigtime_t timeout)
201{
202	if (fUDPService == NULL)
203		return B_NO_INIT;
204
205	if (_packet == NULL)
206		return B_BAD_VALUE;
207
208	bigtime_t startTime = system_time();
209	for (;;) {
210		fUDPService->ProcessIncomingPackets();
211		*_packet = PopPacket();
212		if (*_packet != NULL)
213			return B_OK;
214
215		if (system_time() - startTime > timeout)
216			return (timeout == 0 ? B_WOULD_BLOCK : B_TIMED_OUT);
217	}
218}
219
220
221void
222UDPSocket::PushPacket(UDPPacket *packet)
223{
224	if (fLastPacket != NULL)
225		fLastPacket->SetNext(packet);
226	else
227		fFirstPacket = packet;
228
229	fLastPacket = packet;
230	packet->SetNext(NULL);
231}
232
233
234UDPPacket *
235UDPSocket::PopPacket()
236{
237	if (fFirstPacket == NULL)
238		return NULL;
239
240	UDPPacket *packet = fFirstPacket;
241	fFirstPacket = packet->Next();
242
243	if (fFirstPacket == NULL)
244		fLastPacket = NULL;
245
246	packet->SetNext(NULL);
247	return packet;
248}
249
250
251// #pragma mark - UDPService
252
253
254UDPService::UDPService(IPService *ipService)
255	:
256	IPSubService(kUDPServiceName),
257	fIPService(ipService)
258{
259}
260
261
262UDPService::~UDPService()
263{
264	if (fIPService != NULL)
265		fIPService->UnregisterIPSubService(this);
266}
267
268
269status_t
270UDPService::Init()
271{
272	if (fIPService == NULL)
273		return B_BAD_VALUE;
274	if (!fIPService->RegisterIPSubService(this))
275		return B_NO_MEMORY;
276	return B_OK;
277}
278
279
280uint8
281UDPService::IPProtocol() const
282{
283	return IPPROTO_UDP;
284}
285
286
287void
288UDPService::HandleIPPacket(IPService *ipService, ip_addr_t sourceIP,
289	ip_addr_t destinationIP, const void *data, size_t size)
290{
291	TRACE(("UDPService::HandleIPPacket(): source: %08lx, destination: %08lx, "
292		"%lu - %lu bytes\n", sourceIP, destinationIP, size,
293		sizeof(udp_header)));
294
295	if (data == NULL || size < sizeof(udp_header))
296		return;
297
298	const udp_header *header = (const udp_header*)data;
299	uint16 source = ntohs(header->source);
300	uint16 destination = ntohs(header->destination);
301	uint16 length = ntohs(header->length);
302
303	// check the header
304	if (length < sizeof(udp_header) || length > size
305		|| (header->checksum != 0	// 0 => checksum disabled
306			&& _ChecksumData(data, length, sourceIP, destinationIP) != 0)) {
307		TRACE(("UDPService::HandleIPPacket(): dropping packet -- invalid size "
308			"or checksum\n"));
309		return;
310	}
311
312	// find the target socket
313	UDPSocket *socket = _FindSocket(destinationIP, destination);
314	if (socket == NULL)
315		return;
316
317	// create a UDPPacket and queue it in the socket
318	UDPPacket *packet = new(nothrow) UDPPacket;
319	if (packet == NULL)
320		return;
321	status_t error = packet->SetTo((uint8*)data + sizeof(udp_header),
322		length - sizeof(udp_header), sourceIP, source, destinationIP,
323		destination);
324	if (error == B_OK)
325		socket->PushPacket(packet);
326	else
327		delete packet;
328}
329
330
331status_t
332UDPService::Send(uint16 sourcePort, ip_addr_t destinationAddress,
333	uint16 destinationPort, ChainBuffer *buffer)
334{
335	TRACE(("UDPService::Send(source port: %hu, to: %08lx:%hu, %lu bytes)\n",
336		sourcePort, destinationAddress, destinationPort,
337		(buffer != NULL ? buffer->TotalSize() : 0)));
338
339	if (fIPService == NULL)
340		return B_NO_INIT;
341
342	if (buffer == NULL)
343		return B_BAD_VALUE;
344
345	// prepend the UDP header
346	udp_header header;
347	ChainBuffer headerBuffer(&header, sizeof(header), buffer);
348	header.source = htons(sourcePort);
349	header.destination = htons(destinationPort);
350	header.length = htons(headerBuffer.TotalSize());
351
352	// compute the checksum
353	header.checksum = 0;
354	header.checksum = htons(_ChecksumBuffer(&headerBuffer,
355		fIPService->IPAddress(), destinationAddress,
356		headerBuffer.TotalSize()));
357	// 0 means checksum disabled; 0xffff is equivalent in this case
358	if (header.checksum == 0)
359		header.checksum = 0xffff;
360
361	return fIPService->Send(destinationAddress, IPPROTO_UDP, &headerBuffer);
362}
363
364
365void
366UDPService::ProcessIncomingPackets()
367{
368	if (fIPService != NULL)
369		fIPService->ProcessIncomingPackets();
370}
371
372
373status_t
374UDPService::BindSocket(UDPSocket *socket, ip_addr_t address, uint16 port)
375{
376	if (socket == NULL)
377		return B_BAD_VALUE;
378
379	if (_FindSocket(address, port) != NULL) {
380		printf("UDPService::BindSocket(): address in use\n");
381		return EADDRINUSE;
382	}
383
384	return fSockets.Add(socket);
385}
386
387
388void
389UDPService::UnbindSocket(UDPSocket *socket)
390{
391	fSockets.Remove(socket);
392}
393
394
395uint16
396UDPService::_ChecksumBuffer(ChainBuffer *buffer, ip_addr_t source,
397	ip_addr_t destination, uint16 length)
398{
399	// The checksum is calculated over a pseudo-header plus the UDP packet.
400	// So we temporarily prepend the pseudo-header.
401	struct pseudo_header {
402		ip_addr_t	source;
403		ip_addr_t	destination;
404		uint8		pad;
405		uint8		protocol;
406		uint16		length;
407	} __attribute__ ((__packed__));
408	pseudo_header header = {
409		htonl(source),
410		htonl(destination),
411		0,
412		IPPROTO_UDP,
413		htons(length)
414	};
415
416	ChainBuffer headerBuffer(&header, sizeof(header), buffer);
417	uint16 checksum = ip_checksum(&headerBuffer);
418	headerBuffer.DetachNext();
419	return checksum;
420}
421
422
423uint16
424UDPService::_ChecksumData(const void *data, uint16 length, ip_addr_t source,
425	ip_addr_t destination)
426{
427	ChainBuffer buffer((void*)data, length);
428	return _ChecksumBuffer(&buffer, source, destination, length);
429}
430
431
432UDPSocket *
433UDPService::_FindSocket(ip_addr_t address, uint16 port)
434{
435	int count = fSockets.Count();
436	for (int i = 0; i < count; i++) {
437		UDPSocket *socket = fSockets.ElementAt(i);
438		if ((address == INADDR_ANY || socket->Address() == INADDR_ANY
439				|| socket->Address() == address)
440			&& port == socket->Port()) {
441			return socket;
442		}
443	}
444
445	return NULL;
446}
447