1/*
2 * Copyright 2010-2011, Haiku, Inc. All rights reserved.
3 * Copyright 2010, Clemens Zeidler <haiku@clemens-zeidler.de>
4 * Distributed under the terms of the MIT License.
5 */
6
7
8#include "ServerConnection.h"
9
10#include <errno.h>
11#include <sys/poll.h>
12#include <unistd.h>
13
14#ifdef USE_SSL
15#	include <openssl/ssl.h>
16#	include <openssl/rand.h>
17#endif
18
19#include <Autolock.h>
20#include <Locker.h>
21#include <NetworkAddress.h>
22
23
// Debug tracing. While DEBUG_SERVER_CONNECTION is defined, TRACE() forwards
// its arguments to printf(); otherwise it compiles to nothing.
#define DEBUG_SERVER_CONNECTION
#ifdef DEBUG_SERVER_CONNECTION
#	include <stdio.h>
#	define TRACE(x...) printf(x)
#else
	// Expand to a single statement that still consumes the caller's trailing
	// semicolon, so that "if (...) TRACE(...); else ..." keeps compiling.
#	define TRACE(x...) do {} while (0)
#endif
31
32
33namespace BPrivate {
34
35
36class AbstractConnection {
37public:
38	virtual						~AbstractConnection();
39
40	virtual status_t			Connect(const char* server, uint32 port) = 0;
41	virtual status_t			Disconnect() = 0;
42
43	virtual status_t			WaitForData(bigtime_t timeout) = 0;
44
45	virtual ssize_t				Read(char* buffer, uint32 length) = 0;
46	virtual ssize_t				Write(const char* buffer, uint32 length) = 0;
47};
48
49
50class SocketConnection : public AbstractConnection {
51public:
52								SocketConnection();
53
54			status_t			Connect(const char* server, uint32 port);
55			status_t			Disconnect();
56
57			status_t			WaitForData(bigtime_t timeout);
58
59			ssize_t				Read(char* buffer, uint32 length);
60			ssize_t				Write(const char* buffer, uint32 length);
61
62protected:
63			int					fSocket;
64};
65
66
67#ifdef USE_SSL
68
69
70class InitSSL {
71public:
72	InitSSL()
73	{
74		if (SSL_library_init() != 1) {
75			fInit = false;
76			return;
77		}
78
79		// use combination of more or less random this pointer and system time
80		int64 seed = (int64)this + system_time();
81		RAND_seed(&seed, sizeof(seed));
82
83		fInit = true;
84		return;
85	};
86
87
88	status_t InitCheck()
89	{
90		return fInit ? B_OK : B_ERROR;
91	}
92
93private:
94			bool				fInit;
95};
96
97
98class SSLConnection : public SocketConnection {
99public:
100								SSLConnection();
101
102			status_t			Connect(const char* server, uint32 port);
103			status_t			Disconnect();
104
105			status_t			WaitForData(bigtime_t timeout);
106
107			ssize_t				Read(char* buffer, uint32 length);
108			ssize_t				Write(const char* buffer, uint32 length);
109
110private:
111			SSL_CTX*			fCTX;
112			SSL*				fSSL;
113			BIO*				fBIO;
114};
115
116
// File-local InitSSL instance; its constructor performs the OpenSSL library
// setup, and SSLConnection::Connect() checks the outcome via InitCheck().
static InitSSL gInitSSL;
118
119
120#endif	// USE_SSL
121
122
123AbstractConnection::~AbstractConnection()
124{
125}
126
127
128// #pragma mark -
129
130
131ServerConnection::ServerConnection()
132	:
133	fConnection(NULL)
134{
135}
136
137
138ServerConnection::~ServerConnection()
139{
140	if (fConnection != NULL)
141		fConnection->Disconnect();
142	delete fConnection;
143}
144
145
146status_t
147ServerConnection::ConnectSSL(const char* server, uint32 port)
148{
149#ifdef USE_SSL
150	delete fConnection;
151	fConnection = new SSLConnection;
152	return fConnection->Connect(server, port);
153#else
154	return B_ERROR;
155#endif
156}
157
158
159status_t
160ServerConnection::ConnectSocket(const char* server, uint32 port)
161{
162	delete fConnection;
163	fConnection = new SocketConnection;
164	return fConnection->Connect(server, port);
165}
166
167
168status_t
169ServerConnection::Disconnect()
170{
171	if (fConnection == NULL)
172		return B_ERROR;
173	return fConnection->Disconnect();
174}
175
176
177status_t
178ServerConnection::WaitForData(bigtime_t timeout)
179{
180	if (fConnection == NULL)
181		return B_ERROR;
182	return fConnection->WaitForData(timeout);
183}
184
185
186ssize_t
187ServerConnection::Read(char* buffer, uint32 nBytes)
188{
189	if (fConnection == NULL)
190		return B_ERROR;
191	return fConnection->Read(buffer, nBytes);
192}
193
194
195ssize_t
196ServerConnection::Write(const char* buffer, uint32 nBytes)
197{
198	if (fConnection == NULL)
199		return B_ERROR;
200	return fConnection->Write(buffer, nBytes);
201}
202
203
204// #pragma mark -
205
206
207SocketConnection::SocketConnection()
208	:
209	fSocket(-1)
210{
211}
212
213
214status_t
215SocketConnection::Connect(const char* server, uint32 port)
216{
217	if (fSocket >= 0)
218		Disconnect();
219
220	TRACE("SocketConnection to server %s:%i\n", server, (int)port);
221
222	BNetworkAddress address;
223	status_t status = address.SetTo(server, port);
224	if (status != B_OK) {
225		TRACE("%s: Address Error: %s\n", __func__, strerror(status));
226		return status;
227	}
228
229	TRACE("Server resolves to %s\n", address.ToString().String());
230
231	fSocket = socket(address.Family(), SOCK_STREAM, 0);
232	if (fSocket < 0) {
233		TRACE("%s: Socket Error: %s\n", __func__, strerror(errno));
234		return errno;
235	}
236
237	int result = connect(fSocket, address, address.Length());
238	if (result < 0) {
239		TRACE("%s: Connect Error: %s\n", __func__, strerror(errno));
240		close(fSocket);
241		return errno;
242	}
243
244	TRACE("SocketConnection: connected\n");
245
246	return B_OK;
247}
248
249
250status_t
251SocketConnection::Disconnect()
252{
253	close(fSocket);
254	fSocket = -1;
255	return B_OK;
256}
257
258
259status_t
260SocketConnection::WaitForData(bigtime_t timeout)
261{
262	struct pollfd entry;
263	entry.fd = fSocket;
264	entry.events = POLLIN;
265
266	int timeoutMillis = -1;
267	if (timeout > 0)
268		timeoutMillis = timeout / 1000;
269
270	int result = poll(&entry, 1, timeoutMillis);
271	if (result == 0)
272		return B_TIMED_OUT;
273	if (result < 0)
274		return errno;
275
276	return B_OK;
277}
278
279
280ssize_t
281SocketConnection::Read(char* buffer, uint32 length)
282{
283	ssize_t bytesReceived = recv(fSocket, buffer, length, 0);
284	if (bytesReceived < 0)
285		return errno;
286
287	return bytesReceived;
288}
289
290
291ssize_t
292SocketConnection::Write(const char* buffer, uint32 length)
293{
294	ssize_t bytesWritten = send(fSocket, buffer, length, 0);
295	if (bytesWritten < 0)
296		return errno;
297
298	return bytesWritten;
299}
300
301
302// #pragma mark -
303
304
305#ifdef USE_SSL
306
307
308SSLConnection::SSLConnection()
309	:
310	fCTX(NULL),
311	fSSL(NULL),
312	fBIO(NULL)
313{
314}
315
316
317status_t
318SSLConnection::Connect(const char* server, uint32 port)
319{
320	if (fSSL != NULL)
321		Disconnect();
322
323	if (gInitSSL.InitCheck() != B_OK)
324		return B_ERROR;
325
326	status_t status = SocketConnection::Connect(server, port);
327	if (status != B_OK)
328		return status;
329
330	fCTX = SSL_CTX_new(SSLv23_method());
331	fSSL = SSL_new(fCTX);
332	fBIO = BIO_new_socket(fSocket, BIO_NOCLOSE);
333	SSL_set_bio(fSSL, fBIO, fBIO);
334
335	if (SSL_connect(fSSL) <= 0) {
336		TRACE("SSLConnection can't connect\n");
337		SocketConnection::Disconnect();
338    	return B_ERROR;
339	}
340
341	TRACE("SSLConnection connected\n");
342	return B_OK;
343}
344
345
346status_t
347SSLConnection::Disconnect()
348{
349	TRACE("SSLConnection::Disconnect()\n");
350
351	if (fSSL)
352		SSL_shutdown(fSSL);
353	if (fCTX)
354		SSL_CTX_free(fCTX);
355	if (fBIO)
356		BIO_free(fBIO);
357
358	fSSL = NULL;
359	fCTX = NULL;
360	fBIO = NULL;
361	return SocketConnection::Disconnect();
362}
363
364
365status_t
366SSLConnection::WaitForData(bigtime_t timeout)
367{
368	if (fSSL == NULL)
369		return B_NO_INIT;
370	if (SSL_pending(fSSL) > 0)
371		return B_OK;
372
373	return SocketConnection::WaitForData(timeout);
374}
375
376
377ssize_t
378SSLConnection::Read(char* buffer, uint32 length)
379{
380	if (fSSL == NULL)
381		return B_NO_INIT;
382
383	int bytesRead = SSL_read(fSSL, buffer, length);
384	if (bytesRead > 0)
385		return bytesRead;
386
387	// TODO: translate SSL error codes!
388	return B_ERROR;
389}
390
391
392ssize_t
393SSLConnection::Write(const char* buffer, uint32 length)
394{
395	if (fSSL == NULL)
396		return B_NO_INIT;
397
398	int bytesWritten = SSL_write(fSSL, buffer, length);
399	if (bytesWritten > 0)
400		return bytesWritten;
401
402	// TODO: translate SSL error codes!
403	return B_ERROR;
404}
405
406
407#endif	// USE_SSL
408
409
410}	// namespace BPrivate
411