1/* 2 * Copyright 2010-2011, Haiku, Inc. All rights reserved. 3 * Copyright 2010, Clemens Zeidler <haiku@clemens-zeidler.de> 4 * Distributed under the terms of the MIT License. 5 */ 6 7 8#include "ServerConnection.h" 9 10#include <errno.h> 11#include <sys/poll.h> 12#include <unistd.h> 13 14#ifdef USE_SSL 15# include <openssl/ssl.h> 16# include <openssl/rand.h> 17#endif 18 19#include <Autolock.h> 20#include <Locker.h> 21#include <NetworkAddress.h> 22 23 24#define DEBUG_SERVER_CONNECTION 25#ifdef DEBUG_SERVER_CONNECTION 26# include <stdio.h> 27# define TRACE(x...) printf(x) 28#else 29# define TRACE(x...) ; 30#endif 31 32 33namespace BPrivate { 34 35 36class AbstractConnection { 37public: 38 virtual ~AbstractConnection(); 39 40 virtual status_t Connect(const char* server, uint32 port) = 0; 41 virtual status_t Disconnect() = 0; 42 43 virtual status_t WaitForData(bigtime_t timeout) = 0; 44 45 virtual ssize_t Read(char* buffer, uint32 length) = 0; 46 virtual ssize_t Write(const char* buffer, uint32 length) = 0; 47}; 48 49 50class SocketConnection : public AbstractConnection { 51public: 52 SocketConnection(); 53 54 status_t Connect(const char* server, uint32 port); 55 status_t Disconnect(); 56 57 status_t WaitForData(bigtime_t timeout); 58 59 ssize_t Read(char* buffer, uint32 length); 60 ssize_t Write(const char* buffer, uint32 length); 61 62protected: 63 int fSocket; 64}; 65 66 67#ifdef USE_SSL 68 69 70class InitSSL { 71public: 72 InitSSL() 73 { 74 if (SSL_library_init() != 1) { 75 fInit = false; 76 return; 77 } 78 79 // use combination of more or less random this pointer and system time 80 int64 seed = (int64)this + system_time(); 81 RAND_seed(&seed, sizeof(seed)); 82 83 fInit = true; 84 return; 85 }; 86 87 88 status_t InitCheck() 89 { 90 return fInit ? 
B_OK : B_ERROR; 91 } 92 93private: 94 bool fInit; 95}; 96 97 98class SSLConnection : public SocketConnection { 99public: 100 SSLConnection(); 101 102 status_t Connect(const char* server, uint32 port); 103 status_t Disconnect(); 104 105 status_t WaitForData(bigtime_t timeout); 106 107 ssize_t Read(char* buffer, uint32 length); 108 ssize_t Write(const char* buffer, uint32 length); 109 110private: 111 SSL_CTX* fCTX; 112 SSL* fSSL; 113 BIO* fBIO; 114}; 115 116 117static InitSSL gInitSSL; 118 119 120#endif // USE_SSL 121 122 123AbstractConnection::~AbstractConnection() 124{ 125} 126 127 128// #pragma mark - 129 130 131ServerConnection::ServerConnection() 132 : 133 fConnection(NULL) 134{ 135} 136 137 138ServerConnection::~ServerConnection() 139{ 140 if (fConnection != NULL) 141 fConnection->Disconnect(); 142 delete fConnection; 143} 144 145 146status_t 147ServerConnection::ConnectSSL(const char* server, uint32 port) 148{ 149#ifdef USE_SSL 150 delete fConnection; 151 fConnection = new SSLConnection; 152 return fConnection->Connect(server, port); 153#else 154 return B_ERROR; 155#endif 156} 157 158 159status_t 160ServerConnection::ConnectSocket(const char* server, uint32 port) 161{ 162 delete fConnection; 163 fConnection = new SocketConnection; 164 return fConnection->Connect(server, port); 165} 166 167 168status_t 169ServerConnection::Disconnect() 170{ 171 if (fConnection == NULL) 172 return B_ERROR; 173 return fConnection->Disconnect(); 174} 175 176 177status_t 178ServerConnection::WaitForData(bigtime_t timeout) 179{ 180 if (fConnection == NULL) 181 return B_ERROR; 182 return fConnection->WaitForData(timeout); 183} 184 185 186ssize_t 187ServerConnection::Read(char* buffer, uint32 nBytes) 188{ 189 if (fConnection == NULL) 190 return B_ERROR; 191 return fConnection->Read(buffer, nBytes); 192} 193 194 195ssize_t 196ServerConnection::Write(const char* buffer, uint32 nBytes) 197{ 198 if (fConnection == NULL) 199 return B_ERROR; 200 return fConnection->Write(buffer, nBytes); 
201} 202 203 204// #pragma mark - 205 206 207SocketConnection::SocketConnection() 208 : 209 fSocket(-1) 210{ 211} 212 213 214status_t 215SocketConnection::Connect(const char* server, uint32 port) 216{ 217 if (fSocket >= 0) 218 Disconnect(); 219 220 TRACE("SocketConnection to server %s:%i\n", server, (int)port); 221 222 BNetworkAddress address; 223 status_t status = address.SetTo(server, port); 224 if (status != B_OK) { 225 TRACE("%s: Address Error: %s\n", __func__, strerror(status)); 226 return status; 227 } 228 229 TRACE("Server resolves to %s\n", address.ToString().String()); 230 231 fSocket = socket(address.Family(), SOCK_STREAM, 0); 232 if (fSocket < 0) { 233 TRACE("%s: Socket Error: %s\n", __func__, strerror(errno)); 234 return errno; 235 } 236 237 int result = connect(fSocket, address, address.Length()); 238 if (result < 0) { 239 TRACE("%s: Connect Error: %s\n", __func__, strerror(errno)); 240 close(fSocket); 241 return errno; 242 } 243 244 TRACE("SocketConnection: connected\n"); 245 246 return B_OK; 247} 248 249 250status_t 251SocketConnection::Disconnect() 252{ 253 close(fSocket); 254 fSocket = -1; 255 return B_OK; 256} 257 258 259status_t 260SocketConnection::WaitForData(bigtime_t timeout) 261{ 262 struct pollfd entry; 263 entry.fd = fSocket; 264 entry.events = POLLIN; 265 266 int timeoutMillis = -1; 267 if (timeout > 0) 268 timeoutMillis = timeout / 1000; 269 270 int result = poll(&entry, 1, timeoutMillis); 271 if (result == 0) 272 return B_TIMED_OUT; 273 if (result < 0) 274 return errno; 275 276 return B_OK; 277} 278 279 280ssize_t 281SocketConnection::Read(char* buffer, uint32 length) 282{ 283 ssize_t bytesReceived = recv(fSocket, buffer, length, 0); 284 if (bytesReceived < 0) 285 return errno; 286 287 return bytesReceived; 288} 289 290 291ssize_t 292SocketConnection::Write(const char* buffer, uint32 length) 293{ 294 ssize_t bytesWritten = send(fSocket, buffer, length, 0); 295 if (bytesWritten < 0) 296 return errno; 297 298 return bytesWritten; 299} 300 
// #pragma mark -


#ifdef USE_SSL


SSLConnection::SSLConnection()
	:
	fCTX(NULL),
	fSSL(NULL),
	fBIO(NULL)
{
}


/*!	Opens a TCP connection to \a server : \a port and performs the TLS
	handshake on it. Any previous session is torn down first.

	On failure, every partially created OpenSSL object is released and the
	socket is closed (via Disconnect()), so the object is left in its
	unconnected state. Returns B_ERROR if OpenSSL is not initialized, the
	SSL objects cannot be created, or the handshake fails; otherwise returns
	the result of SocketConnection::Connect().
*/
status_t
SSLConnection::Connect(const char* server, uint32 port)
{
	if (fSSL != NULL)
		Disconnect();

	if (gInitSSL.InitCheck() != B_OK)
		return B_ERROR;

	status_t status = SocketConnection::Connect(server, port);
	if (status != B_OK)
		return status;

	// Each step only runs if the previous one succeeded; fBIO is therefore
	// non-NULL only when all three objects were created.
	fCTX = SSL_CTX_new(SSLv23_method());
	if (fCTX != NULL)
		fSSL = SSL_new(fCTX);
	if (fSSL != NULL)
		fBIO = BIO_new_socket(fSocket, BIO_NOCLOSE);
	if (fBIO == NULL) {
		TRACE("SSLConnection can't create SSL session\n");
		Disconnect();
		return B_ERROR;
	}

	// From here on fSSL owns fBIO; SSL_free() will release it.
	SSL_set_bio(fSSL, fBIO, fBIO);

	if (SSL_connect(fSSL) <= 0) {
		TRACE("SSLConnection can't connect\n");
		// Release the SSL objects as well as the socket; closing only the
		// socket would leak them and leave a stale fSSL behind.
		Disconnect();
		return B_ERROR;
	}

	TRACE("SSLConnection connected\n");
	return B_OK;
}


/*!	Shuts down the TLS session (if its handshake completed), frees all
	OpenSSL objects, and closes the underlying socket.
*/
status_t
SSLConnection::Disconnect()
{
	TRACE("SSLConnection::Disconnect()\n");

	if (fSSL != NULL) {
		// Only send close_notify on an established session; a session whose
		// handshake failed is simply discarded.
		if (SSL_is_init_finished(fSSL))
			SSL_shutdown(fSSL);
		// SSL_free() also frees the BIO that SSL_set_bio() attached, so
		// fBIO must not be freed separately here.
		SSL_free(fSSL);
	} else if (fBIO != NULL) {
		// Defensive: a BIO that was never attached to an SSL object.
		BIO_free(fBIO);
	}

	if (fCTX != NULL)
		SSL_CTX_free(fCTX);

	fSSL = NULL;
	fCTX = NULL;
	fBIO = NULL;
	return SocketConnection::Disconnect();
}


/*!	Returns B_OK immediately if OpenSSL already holds decrypted bytes,
	otherwise waits on the socket like SocketConnection::WaitForData().
	Returns B_NO_INIT when there is no TLS session.
*/
status_t
SSLConnection::WaitForData(bigtime_t timeout)
{
	if (fSSL == NULL)
		return B_NO_INIT;
	if (SSL_pending(fSSL) > 0)
		return B_OK;

	return SocketConnection::WaitForData(timeout);
}


/*!	Reads up to \a length decrypted bytes. Returns the number of bytes read,
	0 when the peer closed the TLS session cleanly (matching recv()'s
	end-of-stream result in SocketConnection::Read()), B_NO_INIT without a
	session, or B_ERROR for any other failure.
*/
ssize_t
SSLConnection::Read(char* buffer, uint32 length)
{
	if (fSSL == NULL)
		return B_NO_INIT;

	int bytesRead = SSL_read(fSSL, buffer, length);
	if (bytesRead > 0)
		return bytesRead;

	if (SSL_get_error(fSSL, bytesRead) == SSL_ERROR_ZERO_RETURN)
		return 0;

	// TODO: translate the remaining SSL error codes!
	return B_ERROR;
}


/*!	Encrypts and writes up to \a length bytes. Returns the number of bytes
	written, B_NO_INIT without a session, or B_ERROR on failure.
*/
ssize_t
SSLConnection::Write(const char* buffer, uint32 length)
{
	if (fSSL == NULL)
		return B_NO_INIT;

	int bytesWritten = SSL_write(fSSL, buffer, length);
	if (bytesWritten > 0)
		return bytesWritten;

	// TODO: translate SSL error codes!
	return B_ERROR;
}


#endif	// USE_SSL


}	// namespace BPrivate