1/*
2 * Copyright 2005-2007, Ingo Weinhold <bonefish@cs.tu-berlin.de>.
3 * All rights reserved. Distributed under the terms of the MIT License.
4 */
5
6#include <endian.h>
7#include <errno.h>
8#include <fcntl.h>
9#include <inttypes.h>
10#include <stdio.h>
11#include <stdlib.h>
12#include <string.h>
13#include <unistd.h>
14#include <netinet/in.h>
15#include <sys/socket.h>
16#include <sys/stat.h>
17
18#include <boot/net/RemoteDiskDefs.h>
19
20
#if defined(__BEOS__) && !defined(__HAIKU__)
// Pre-Haiku BeOS headers presumably do not provide socklen_t.
typedef int socklen_t;
#endif


/*
 * 64-bit host <-> network byte order conversion.
 *
 * Network byte order is big-endian, so a little-endian host reverses the
 * eight bytes of the value while a big-endian host passes it through.
 */
#if __BYTE_ORDER == __LITTLE_ENDIAN

/* Returns `data` with its eight bytes in reverse order. */
static inline
uint64_t swap_uint64(uint64_t data)
{
	uint64_t reversed = 0;

	// peel bytes off the low end of `data` and push them onto the low end
	// of `reversed`, so the least significant byte ends up most significant
	for (int byteIndex = 0; byteIndex < 8; byteIndex++) {
		reversed = (reversed << 8) | (data & 0xff);
		data >>= 8;
	}

	return reversed;
}

#define host_to_net64(data)	swap_uint64(data)
#define net_to_host64(data)	swap_uint64(data)

#endif

#if __BYTE_ORDER == __BIG_ENDIAN
#define host_to_net64(data)	(data)
#define net_to_host64(data)	(data)
#endif

// Replace any platform-provided definitions with the ones above.
#undef htonll
#undef ntohll
#define htonll(data)	host_to_net64(data)
#define ntohll(data)	net_to_host64(data)
55
56
57class Server {
58public:
59	Server(const char *fileName)
60		: fImagePath(fileName),
61		  fImageFD(-1),
62		  fImageSize(0),
63		  fSocket(-1)
64	{
65	}
66
67	int Run()
68	{
69		_CreateSocket();
70
71		// main server loop
72		for (;;) {
73			// receive
74			fClientAddress.sin_family = AF_INET;
75			fClientAddress.sin_port = 0;
76			fClientAddress.sin_addr.s_addr = htonl(INADDR_ANY);
77			socklen_t addrSize = sizeof(fClientAddress);
78			char buffer[2048];
79			ssize_t bytesRead = recvfrom(fSocket, buffer, sizeof(buffer), 0,
80								(sockaddr*)&fClientAddress, &addrSize);
81			// handle error
82			if (bytesRead < 0) {
83				if (errno == EINTR)
84					continue;
85				fprintf(stderr, "Error: Failed to read from socket: %s.\n",
86					strerror(errno));
87				exit(1);
88			}
89
90			// short package?
91			if (bytesRead < (ssize_t)sizeof(remote_disk_header)) {
92				fprintf(stderr, "Dropping short request package (%d bytes).\n",
93					(int)bytesRead);
94				continue;
95			}
96
97			fRequest = (remote_disk_header*)buffer;
98			fRequestSize = bytesRead;
99
100			switch (fRequest->command) {
101				case REMOTE_DISK_HELLO_REQUEST:
102					_HandleHelloRequest();
103					break;
104
105				case REMOTE_DISK_READ_REQUEST:
106					_HandleReadRequest();
107					break;
108
109				case REMOTE_DISK_WRITE_REQUEST:
110					_HandleWriteRequest();
111					break;
112
113				default:
114					fprintf(stderr, "Ignoring invalid request %d.\n",
115						(int)fRequest->command);
116					break;
117			}
118		}
119
120		return 0;
121	}
122
123private:
124	void _OpenImage(bool reopen)
125	{
126		// already open?
127		if (fImageFD >= 0) {
128			if (!reopen)
129				return;
130
131			close(fImageFD);
132			fImageFD = -1;
133			fImageSize = 0;
134		}
135
136		// open the image
137		fImageFD = open(fImagePath, O_RDWR);
138		if (fImageFD < 0) {
139			fprintf(stderr, "Error: Failed to open \"%s\": %s.\n", fImagePath,
140				strerror(errno));
141			exit(1);
142		}
143
144		// get its size
145		struct stat st;
146		if (fstat(fImageFD, &st) < 0) {
147			fprintf(stderr, "Error: Failed to stat \"%s\": %s.\n", fImagePath,
148				strerror(errno));
149			exit(1);
150		}
151		fImageSize = st.st_size;
152	}
153
154	void _CreateSocket()
155	{
156		// create a socket
157		fSocket = socket(AF_INET, SOCK_DGRAM, 0);
158		if (fSocket < 0) {
159			fprintf(stderr, "Error: Failed to create a socket: %s.",
160				strerror(errno));
161			exit(1);
162		}
163
164		// bind it to the port
165		sockaddr_in addr;
166		addr.sin_family = AF_INET;
167		addr.sin_port = htons(REMOTE_DISK_SERVER_PORT);
168		addr.sin_addr.s_addr = INADDR_ANY;
169		if (bind(fSocket, (sockaddr*)&addr, sizeof(addr)) < 0) {
170			fprintf(stderr, "Error: Failed to bind socket to port %hu: %s\n",
171				REMOTE_DISK_SERVER_PORT, strerror(errno));
172			exit(1);
173		}
174	}
175
176	void _HandleHelloRequest()
177	{
178		printf("HELLO request\n");
179
180		_OpenImage(true);
181
182		remote_disk_header reply;
183		reply.offset = htonll(fImageSize);
184
185		reply.command = REMOTE_DISK_HELLO_REPLY;
186		_SendReply(&reply, sizeof(remote_disk_header));
187	}
188
189	void _HandleReadRequest()
190	{
191		_OpenImage(false);
192
193		char buffer[2048];
194		remote_disk_header *reply = (remote_disk_header*)buffer;
195		uint64_t offset = ntohll(fRequest->offset);
196		int16_t size = ntohs(fRequest->size);
197		int16_t result = 0;
198
199		printf("READ request: offset: %" PRIu64 ", %hd bytes\n", offset, size);
200
201		if (offset < (uint64_t)fImageSize && size > 0) {
202			// always read 1024 bytes
203			size = REMOTE_DISK_BLOCK_SIZE;
204			if (offset + size > (uint64_t)fImageSize)
205				size = fImageSize - offset;
206
207			// seek to the offset
208			off_t oldOffset = lseek(fImageFD, offset, SEEK_SET);
209			if (oldOffset >= 0) {
210				// read
211				ssize_t bytesRead = read(fImageFD, reply->data, size);
212				if (bytesRead >= 0) {
213					result = bytesRead;
214				} else {
215					fprintf(stderr, "Error: Failed to read at position %" PRIu64 ": "
216						"%s.", offset, strerror(errno));
217					result = REMOTE_DISK_IO_ERROR;
218				}
219			} else {
220				fprintf(stderr, "Error: Failed to seek to position %" PRIu64 ": %s.",
221					offset, strerror(errno));
222				result = REMOTE_DISK_IO_ERROR;
223			}
224		}
225
226		// send reply
227		reply->command = REMOTE_DISK_READ_REPLY;
228		reply->offset = htonll(offset);
229		reply->size = htons(result);
230		_SendReply(reply, sizeof(*reply) + (result >= 0 ? result : 0));
231	}
232
233	void _HandleWriteRequest()
234	{
235		_OpenImage(false);
236
237		remote_disk_header reply;
238		uint64_t offset = ntohll(fRequest->offset);
239		int16_t size = ntohs(fRequest->size);
240		int16_t result = 0;
241
242		printf("WRITE request: offset: %" PRIu64 ", %hd bytes\n", offset, size);
243
244		if (size < 0
245			|| (uint32_t)size > fRequestSize - sizeof(remote_disk_header)
246			|| offset > (uint64_t)fImageSize) {
247			result = REMOTE_DISK_BAD_REQUEST;
248		} else if (offset < (uint64_t)fImageSize && size > 0) {
249			if (offset + size > (uint64_t)fImageSize)
250				size = fImageSize - offset;
251
252			// seek to the offset
253			off_t oldOffset = lseek(fImageFD, offset, SEEK_SET);
254			if (oldOffset >= 0) {
255				// write
256				ssize_t bytesWritten = write(fImageFD, fRequest->data, size);
257				if (bytesWritten >= 0) {
258					result = bytesWritten;
259				} else {
260					fprintf(stderr, "Error: Failed to write at position %" PRIu64 ": "
261						"%s.", offset, strerror(errno));
262					result = REMOTE_DISK_IO_ERROR;
263				}
264			} else {
265				fprintf(stderr, "Error: Failed to seek to position %" PRIu64 ": %s.",
266					offset, strerror(errno));
267				result = REMOTE_DISK_IO_ERROR;
268			}
269		}
270
271		// send reply
272		reply.command = REMOTE_DISK_WRITE_REPLY;
273		reply.offset = htonll(offset);
274		reply.size = htons(result);
275		_SendReply(&reply, sizeof(reply));
276	}
277
278	void _SendReply(remote_disk_header *reply, int size)
279	{
280		reply->request_id = fRequest->request_id;
281		reply->port = htons(REMOTE_DISK_SERVER_PORT);
282
283		for (;;) {
284			ssize_t bytesSent = sendto(fSocket, reply, size, 0,
285				(const sockaddr*)&fClientAddress, sizeof(fClientAddress));
286
287			if (bytesSent < 0) {
288				if (errno == EINTR)
289					continue;
290				fprintf(stderr, "Error: Failed to send reply to client: %s.\n",
291					strerror(errno));
292			}
293			break;
294		}
295	}
296
297private:
298	const char			*fImagePath;
299	int					fImageFD;
300	off_t				fImageSize;
301	int					fSocket;
302	remote_disk_header	*fRequest;
303	ssize_t				fRequestSize;
304	sockaddr_in			fClientAddress;
305};
306
307
308// main
309int
310main(int argc, const char *const *argv)
311{
312	if (argc != 2) {
313		fprintf(stderr, "Usage: %s <image path>\n", argv[0]);
314		exit(1);
315	}
316	const char *fileName = argv[1];
317
318	Server server(fileName);
319	return server.Run();
320}
321