1// InsecureConnection.cpp
2
3// TODO: Asynchronous connecting on client side?
4
5#include <new>
6
7#include <errno.h>
8#include <netdb.h>
9#include <stdio.h>
10#include <string.h>
11#include <unistd.h>
12
13#include <ByteOrder.h>
14
15#include "Compatibility.h"
16#include "DebugSupport.h"
17#include "InsecureChannel.h"
18#include "InsecureConnection.h"
19#include "NetAddress.h"
20#include "NetFSDefs.h"
21
22namespace InsecureConnectionDefs {
23
24const int32 kProtocolVersion = 1;
25
26const bigtime_t kAcceptingTimeout = 10000000;	// 10 s
27
28// number of client up/down stream channels
29const int32 kMinUpStreamChannels		= 1;
30const int32 kMaxUpStreamChannels		= 10;
31const int32 kDefaultUpStreamChannels	= 5;
32const int32 kMinDownStreamChannels		= 1;
33const int32 kMaxDownStreamChannels		= 5;
34const int32 kDefaultDownStreamChannels	= 1;
35
36} // namespace InsecureConnectionDefs
37
38using namespace InsecureConnectionDefs;
39
40// SocketCloser
41struct SocketCloser {
42	SocketCloser(int fd) : fFD(fd) {}
43	~SocketCloser()
44	{
45		if (fFD >= 0)
46			closesocket(fFD);
47	}
48
49	int Detach()
50	{
51		int fd = fFD;
52		fFD = -1;
53		return fd;
54	}
55
56private:
57	int	fFD;
58};
59
60
61// #pragma mark -
62// #pragma mark ----- InsecureConnection -----
63
64// constructor
65InsecureConnection::InsecureConnection()
66{
67}
68
69// destructor
70InsecureConnection::~InsecureConnection()
71{
72}
73
74// Init (server side)
75status_t
76InsecureConnection::Init(int fd)
77{
78	status_t error = AbstractConnection::Init();
79	if (error != B_OK) {
80		closesocket(fd);
81		return error;
82	}
83	// create the initial channel
84	Channel* channel = new(std::nothrow) InsecureChannel(fd);
85	if (!channel) {
86		closesocket(fd);
87		return B_NO_MEMORY;
88	}
89	// add it
90	error = AddDownStreamChannel(channel);
91	if (error != B_OK) {
92		delete channel;
93		return error;
94	}
95	return B_OK;
96}
97
98// Init (client side)
99status_t
100InsecureConnection::Init(const char* parameters)
101{
102PRINT(("InsecureConnection::Init\n"));
103	if (!parameters)
104		return B_BAD_VALUE;
105	status_t error = AbstractConnection::Init();
106	if (error != B_OK)
107		return error;
108	// parse the parameters to get a server name and a port we shall connect to
109	// parameter format is "<server>[:port] [ <up> [ <down> ] ]"
110	char server[256];
111	uint16 port = kDefaultInsecureConnectionPort;
112	int upStreamChannels = kDefaultUpStreamChannels;
113	int downStreamChannels = kDefaultDownStreamChannels;
114	if (strchr(parameters, ':')) {
115		int result = sscanf(parameters, "%255[^:]:%hu %d %d", server, &port,
116			&upStreamChannels, &downStreamChannels);
117		if (result < 2)
118			return B_BAD_VALUE;
119	} else {
120		int result = sscanf(parameters, "%255[^:] %d %d", server,
121			&upStreamChannels, &downStreamChannels);
122		if (result < 1)
123			return B_BAD_VALUE;
124	}
125	// resolve server address
126	NetAddress netAddress;
127	error = NetAddressResolver().GetHostAddress(server, &netAddress);
128	if (error != B_OK)
129		return error;
130	in_addr serverAddr = netAddress.GetAddress().sin_addr;
131	// open the initial channel
132	Channel* channel;
133	error = _OpenClientChannel(serverAddr, port, &channel);
134	if (error != B_OK)
135		return error;
136	error = AddUpStreamChannel(channel);
137	if (error != B_OK) {
138		delete channel;
139		return error;
140	}
141	// send the server a connect request
142	ConnectRequest request;
143	request.protocolVersion = B_HOST_TO_BENDIAN_INT32(kProtocolVersion);
144	request.serverAddress = serverAddr.s_addr;
145	request.upStreamChannels = B_HOST_TO_BENDIAN_INT32(upStreamChannels);
146	request.downStreamChannels = B_HOST_TO_BENDIAN_INT32(downStreamChannels);
147	error = channel->Send(&request, sizeof(ConnectRequest));
148	if (error != B_OK)
149		return error;
150	// get the server reply
151	ConnectReply reply;
152	error = channel->Receive(&reply, sizeof(ConnectReply));
153	if (error != B_OK)
154		return error;
155	error = B_BENDIAN_TO_HOST_INT32(reply.error);
156	if (error != B_OK)
157		return error;
158	upStreamChannels = B_BENDIAN_TO_HOST_INT32(reply.upStreamChannels);
159	downStreamChannels = B_BENDIAN_TO_HOST_INT32(reply.downStreamChannels);
160	port = B_BENDIAN_TO_HOST_INT16(reply.port);
161	// open the remaining channels
162	int32 allChannels = upStreamChannels + downStreamChannels;
163	for (int32 i = 1; i < allChannels; i++) {
164		PRINT("  creating channel %" B_PRId32 "\n", i);
165		// open the channel
166		error = _OpenClientChannel(serverAddr, port, &channel);
167		if (error != B_OK)
168			RETURN_ERROR(error);
169		// add it
170		if (i < upStreamChannels)
171			error = AddUpStreamChannel(channel);
172		else
173			error = AddDownStreamChannel(channel);
174		if (error != B_OK) {
175			delete channel;
176			return error;
177		}
178	}
179	return B_OK;
180}
181
182// FinishInitialization
183status_t
184InsecureConnection::FinishInitialization()
185{
186PRINT(("InsecureConnection::FinishInitialization()\n"));
187	// get the down stream channel
188	InsecureChannel* channel
189		= dynamic_cast<InsecureChannel*>(DownStreamChannelAt(0));
190	if (!channel)
191		return B_BAD_VALUE;
192	// receive the connect request
193	ConnectRequest request;
194	status_t error = channel->Receive(&request, sizeof(ConnectRequest));
195	if (error != B_OK)
196		return error;
197	// check the protocol version
198	int32 protocolVersion = B_BENDIAN_TO_HOST_INT32(request.protocolVersion);
199	if (protocolVersion != kProtocolVersion) {
200		_SendErrorReply(channel, B_ERROR);
201		return B_ERROR;
202	}
203	// get our address (we need it for binding)
204	in_addr serverAddr;
205	serverAddr.s_addr = request.serverAddress;
206	// check number of up and down stream channels
207	int32 upStreamChannels = B_BENDIAN_TO_HOST_INT32(request.upStreamChannels);
208	int32 downStreamChannels = B_BENDIAN_TO_HOST_INT32(
209		request.downStreamChannels);
210	if (upStreamChannels < kMinUpStreamChannels)
211		upStreamChannels = kMinUpStreamChannels;
212	else if (upStreamChannels > kMaxUpStreamChannels)
213		upStreamChannels = kMaxUpStreamChannels;
214	if (downStreamChannels < kMinDownStreamChannels)
215		downStreamChannels = kMinDownStreamChannels;
216	else if (downStreamChannels > kMaxDownStreamChannels)
217		downStreamChannels = kMaxDownStreamChannels;
218	// due to a bug on BONE we have a maximum of 2 working connections
219	// accepted on one listener socket.
220	NetAddress peerAddress;
221	if (channel->GetPeerAddress(&peerAddress) == B_OK
222		&& peerAddress.IsLocal()) {
223		upStreamChannels = 1;
224		if (downStreamChannels > 2)
225			downStreamChannels = 2;
226	}
227	int32 allChannels = upStreamChannels + downStreamChannels;
228	// create a listener socket
229	int fd = socket(AF_INET, SOCK_STREAM, 0);
230	if (fd < 0) {
231		error = errno;
232		_SendErrorReply(channel, error);
233		return error;
234	}
235	SocketCloser _(fd);
236	// bind it to some port
237	sockaddr_in addr;
238	addr.sin_family = AF_INET;
239	addr.sin_port = 0;
240	addr.sin_addr = serverAddr;
241	if (bind(fd, (sockaddr*)&addr, sizeof(addr)) < 0) {
242		error = errno;
243		_SendErrorReply(channel, error);
244		return error;
245	}
246	// get the port
247	socklen_t addrSize = sizeof(addr);
248	if (getsockname(fd, (sockaddr*)&addr, &addrSize) < 0) {
249		error = errno;
250		_SendErrorReply(channel, error);
251		return error;
252	}
253	// set socket to non-blocking
254	int dontBlock = 1;
255	if (setsockopt(fd, SOL_SOCKET, SO_NONBLOCK, &dontBlock, sizeof(int)) < 0) {
256		error = errno;
257		_SendErrorReply(channel, error);
258		return error;
259	}
260	// start listening
261	if (listen(fd, allChannels - 1) < 0) {
262		error = errno;
263		_SendErrorReply(channel, error);
264		return error;
265	}
266	// send the reply
267	ConnectReply reply;
268	reply.error = B_HOST_TO_BENDIAN_INT32(B_OK);
269	reply.upStreamChannels = B_HOST_TO_BENDIAN_INT32(upStreamChannels);
270	reply.downStreamChannels = B_HOST_TO_BENDIAN_INT32(downStreamChannels);
271	reply.port = addr.sin_port;
272	error = channel->Send(&reply, sizeof(ConnectReply));
273	if (error != B_OK)
274		return error;
275	// start accepting
276	bigtime_t startAccepting = system_time();
277	for (int32 i = 1; i < allChannels; ) {
278		// accept a connection
279		int channelFD = accept(fd, NULL, 0);
280		if (channelFD < 0) {
281			error = errno;
282			if (error == B_INTERRUPTED) {
283				error = B_OK;
284				continue;
285			}
286			if (error == B_WOULD_BLOCK) {
287				bigtime_t now = system_time();
288				if (now - startAccepting > kAcceptingTimeout)
289					RETURN_ERROR(B_TIMED_OUT);
290				snooze(10000);
291				continue;
292			}
293			RETURN_ERROR(error);
294		}
295		PRINT("  accepting channel %" B_PRId32 "\n", i);
296		// create a channel
297		channel = new(std::nothrow) InsecureChannel(channelFD);
298		if (!channel) {
299			closesocket(channelFD);
300			return B_NO_MEMORY;
301		}
302		// add it
303		if (i < upStreamChannels)	// inverse, since we are on server side
304			error = AddDownStreamChannel(channel);
305		else
306			error = AddUpStreamChannel(channel);
307		if (error != B_OK) {
308			delete channel;
309			return error;
310		}
311		i++;
312		startAccepting = system_time();
313	}
314	return B_OK;
315}
316
317// _OpenClientChannel
318status_t
319InsecureConnection::_OpenClientChannel(in_addr serverAddr, uint16 port,
320	Channel** _channel)
321{
322	// create a socket
323	int fd = socket(AF_INET, SOCK_STREAM, 0);
324	if (fd < 0)
325		return errno;
326	// connect
327	sockaddr_in addr;
328	addr.sin_family = AF_INET;
329	addr.sin_port = htons(port);
330	addr.sin_addr = serverAddr;
331	if (connect(fd, (sockaddr*)&addr, sizeof(addr)) < 0) {
332		status_t error = errno;
333		closesocket(fd);
334		RETURN_ERROR(error);
335	}
336	// create the channel
337	Channel* channel = new(std::nothrow) InsecureChannel(fd);
338	if (!channel) {
339		closesocket(fd);
340		return B_NO_MEMORY;
341	}
342	*_channel = channel;
343	return B_OK;
344}
345
346// _SendErrorReply
347status_t
348InsecureConnection::_SendErrorReply(Channel* channel, status_t error)
349{
350	ConnectReply reply;
351	reply.error = B_HOST_TO_BENDIAN_INT32(error);
352	return channel->Send(&reply, sizeof(ConnectReply));
353}
354
355