1/*
2 * Copyright 2005, Ingo Weinhold <bonefish@cs.tu-berlin.de>.
3 * All rights reserved. Distributed under the terms of the MIT License.
4 */
5
6#include <boot/net/IP.h>
7
8#include <stdio.h>
9#include <KernelExport.h>
10
11#include <boot/net/ARP.h>
12#include <boot/net/ChainBuffer.h>
13
14
// Debug tracing for this file: remove the comment marker from the line below
// to route TRACE((format, ...)) through dprintf(); otherwise TRACE() expands
// to an empty statement.
//#define TRACE_IP
#if defined(TRACE_IP)
#	define TRACE(x) dprintf x
#else
#	define TRACE(x) ;
#endif
21
22
23// #pragma mark - IPSubService
24
25// constructor
26IPSubService::IPSubService(const char *serviceName)
27	: NetService(serviceName)
28{
29}
30
31// destructor
32IPSubService::~IPSubService()
33{
34}
35
36
37// #pragma mark - IPService
38
39// constructor
40IPService::IPService(EthernetService *ethernet, ARPService *arpService)
41	: EthernetSubService(kIPServiceName),
42		fEthernet(ethernet),
43		fARPService(arpService)
44{
45}
46
47// destructor
48IPService::~IPService()
49{
50	if (fEthernet)
51		fEthernet->UnregisterEthernetSubService(this);
52}
53
54// Init
55status_t
56IPService::Init()
57{
58	if (!fEthernet)
59		return B_BAD_VALUE;
60	if (!fEthernet->RegisterEthernetSubService(this))
61		return B_NO_MEMORY;
62	return B_OK;
63}
64
65// IPAddress
66ip_addr_t
67IPService::IPAddress() const
68{
69	return (fEthernet ? fEthernet->IPAddress() : INADDR_ANY);
70}
71
72// EthernetProtocol
73uint16
74IPService::EthernetProtocol() const
75{
76	return ETHERTYPE_IP;
77}
78
79// HandleEthernetPacket
80void
81IPService::HandleEthernetPacket(EthernetService *ethernet,
82	const mac_addr_t &targetAddress, const void *data, size_t size)
83{
84	TRACE(("IPService::HandleEthernetPacket(): %lu - %lu bytes\n", size,
85		sizeof(ip_header)));
86
87	if (!data || size < sizeof(ip_header))
88		return;
89
90	// check header
91	const ip_header *header = (const ip_header*)data;
92	// header length OK?
93	int headerLength = header->header_length * 4;
94	if (headerLength < 20 || headerLength > (int)size
95		// IP V4?
96		|| header->version != IP_PROTOCOL_VERSION_4
97		// length OK?
98		|| ntohs(header->total_length) > size
99		// broadcast or our IP?
100		|| (header->destination != htonl(INADDR_BROADCAST)
101			&& (fEthernet->IPAddress() == INADDR_ANY
102				|| header->destination != htonl(fEthernet->IPAddress())))
103		// checksum OK?
104		|| _Checksum(*header) != 0) {
105		return;
106	}
107
108	// find a service handling this kind of packet
109	int serviceCount = fServices.Count();
110	for (int i = 0; i < serviceCount; i++) {
111		IPSubService *service = fServices.ElementAt(i);
112		if (service->IPProtocol() == header->protocol) {
113			service->HandleIPPacket(this, ntohl(header->source),
114				ntohl(header->destination),
115				(uint8*)data + headerLength,
116				ntohs(header->total_length) - headerLength);
117			break;
118		}
119	}
120}
121
122// Send
123status_t
124IPService::Send(ip_addr_t destination, uint8 protocol, ChainBuffer *buffer)
125{
126	TRACE(("IPService::Send(to: %08lx, proto: %lu, %lu bytes)\n", destination,
127		(uint32)protocol, (buffer ? buffer->TotalSize() : 0)));
128
129	if (!buffer)
130		return B_BAD_VALUE;
131
132	if (!fEthernet || !fARPService)
133		return B_NO_INIT;
134
135	// prepare header
136	ip_header header;
137	ChainBuffer headerBuffer(&header, sizeof(header), buffer);
138	header.header_length = 5;	// 5 32 bit words, no options
139	header.version = IP_PROTOCOL_VERSION_4;
140	header.type_of_service = 0;
141	header.total_length = htons(headerBuffer.TotalSize());
142	header.identifier = 0;
143	header.fragment_offset = htons(IP_DONT_FRAGMENT);
144	header.time_to_live = IP_DEFAULT_TIME_TO_LIVE;
145	header.protocol = protocol;
146	header.checksum = 0;
147	header.source = htonl(fEthernet->IPAddress());
148	header.destination = htonl(destination);
149
150	// compute check sum
151	header.checksum = htons(_Checksum(header));
152
153	// get target MAC address
154	mac_addr_t targetMAC;
155	status_t error = fARPService->GetMACForIP(destination, targetMAC);
156	if (error != B_OK)
157		return error;
158
159	// send the packet
160	return fEthernet->Send(targetMAC, ETHERTYPE_IP, &headerBuffer);
161}
162
163// ProcessIncomingPackets
164void
165IPService::ProcessIncomingPackets()
166{
167	if (fEthernet)
168		fEthernet->ProcessIncomingPackets();
169}
170
171// RegisterIPSubService
172bool
173IPService::RegisterIPSubService(IPSubService *service)
174{
175	return (service && fServices.Add(service) == B_OK);
176}
177
178// UnregisterIPSubService
179bool
180IPService::UnregisterIPSubService(IPSubService *service)
181{
182	return (service && fServices.Remove(service) >= 0);
183}
184
185// CountSubNetServices
186int
187IPService::CountSubNetServices() const
188{
189	return fServices.Count();
190}
191
192// SubNetServiceAt
193NetService *
194IPService::SubNetServiceAt(int index) const
195{
196	return fServices.ElementAt(index);
197}
198
199// _Checksum
200uint16
201IPService::_Checksum(const ip_header &header)
202{
203	ChainBuffer buffer((void*)&header, header.header_length * 4);
204	return ip_checksum(&buffer);
205}
206
207
208// #pragma mark -
209
210// ip_checksum
211uint16
212ip_checksum(ChainBuffer *buffer)
213{
214	// ChainBuffer iterator returning a stream of uint16 (big endian).
215	struct Iterator {
216		Iterator(ChainBuffer *buffer)
217			: fBuffer(buffer),
218			  fOffset(-1)
219		{
220			_Next();
221		}
222
223		bool HasNext() const
224		{
225			return fBuffer;
226		}
227
228		uint16 Next()
229		{
230			uint16 byte = _NextByte();
231			return (byte << 8) | _NextByte();
232		}
233
234	private:
235		void _Next()
236		{
237			while (fBuffer) {
238				fOffset++;
239				if (fOffset < (int)fBuffer->Size())
240					break;
241
242				fOffset = -1;
243				fBuffer = fBuffer->Next();
244			}
245		}
246
247		uint8 _NextByte()
248		{
249			uint8 byte = (fBuffer ? ((uint8*)fBuffer->Data())[fOffset] : 0);
250			_Next();
251			return byte;
252		}
253
254		ChainBuffer	*fBuffer;
255		int			fOffset;
256	};
257
258	Iterator it(buffer);
259
260	uint32 checksum = 0;
261	while (it.HasNext()) {
262		checksum += it.Next();
263		while (checksum >> 16)
264			checksum = (checksum & 0xffff) + (checksum >> 16);
265	}
266
267	return ~checksum;
268}
269
270
271ip_addr_t
272ip_parse_address(const char *string)
273{
274	ip_addr_t address = 0;
275	int components = 0;
276
277	// TODO: Handles only IPv4 addresses for now.
278	while (components < 4) {
279		address |= strtol(string, NULL, 0) << ((4 - components - 1) * 8);
280
281		const char *dot = strchr(string, '.');
282		if (dot == NULL)
283			break;
284
285		string = dot + 1;
286		components++;
287	}
288
289	return address;
290}
291