1/*
 * Copyright 2011, Axel Dörfler, axeld@pinc-software.de.
3 * Copyright 2016, Rene Gollent, rene@gollent.com.
4 * Distributed under the terms of the MIT License.
5 */
6
7
#include <AbstractSocket.h>

#include <arpa/inet.h>
#include <errno.h>
#include <fcntl.h>
#include <limits.h>
#include <netinet/in.h>
#include <sys/poll.h>
#include <sys/time.h>
16
17
//#define TRACE_SOCKET
#ifdef TRACE_SOCKET
	// printf() and strerror() are only needed when tracing is compiled in.
#	include <stdio.h>
#	include <string.h>
#	define TRACE(x...) printf(x)
#else
	// Expand to a single empty statement, so that "TRACE(...);" is also
	// valid as the body of an unbraced if/else (a bare ";" would not be).
#	define TRACE(x...) do {} while (false)
#endif
24
25
26BAbstractSocket::BAbstractSocket()
27	:
28	fInitStatus(B_NO_INIT),
29	fSocket(-1),
30	fIsBound(false),
31	fIsConnected(false),
32	fIsListening(false)
33{
34}
35
36
37BAbstractSocket::BAbstractSocket(const BAbstractSocket& other)
38	:
39	fInitStatus(other.fInitStatus),
40	fLocal(other.fLocal),
41	fPeer(other.fPeer),
42	fIsConnected(other.fIsConnected),
43	fIsListening(other.fIsListening)
44{
45	fSocket = dup(other.fSocket);
46	if (fSocket < 0)
47		fInitStatus = errno;
48}
49
50
51BAbstractSocket::~BAbstractSocket()
52{
53	Disconnect();
54}
55
56
57status_t
58BAbstractSocket::InitCheck() const
59{
60	return fInitStatus;
61}
62
63
64bool
65BAbstractSocket::IsBound() const
66{
67	return fIsBound;
68}
69
70
71bool
72BAbstractSocket::IsListening() const
73{
74	return fIsListening;
75}
76
77
78bool
79BAbstractSocket::IsConnected() const
80{
81	return fIsConnected;
82}
83
84
85status_t
86BAbstractSocket::Listen(int backlog)
87{
88	if (!fIsBound)
89		return B_NO_INIT;
90
91	if (listen(Socket(), backlog) != 0)
92		return fInitStatus = errno;
93
94	fIsListening = true;
95	return B_OK;
96}
97
98
99void
100BAbstractSocket::Disconnect()
101{
102	if (fSocket < 0)
103		return;
104
105	TRACE("%p: BAbstractSocket::Disconnect()\n", this);
106
107	close(fSocket);
108	fSocket = -1;
109	fIsConnected = false;
110	fIsBound = false;
111}
112
113
114status_t
115BAbstractSocket::SetTimeout(bigtime_t timeout)
116{
117	if (timeout < 0)
118		timeout = 0;
119
120	struct timeval tv;
121	tv.tv_sec = timeout / 1000000LL;
122	tv.tv_usec = timeout % 1000000LL;
123
124	if (setsockopt(fSocket, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(timeval)) != 0
125		|| setsockopt(fSocket, SOL_SOCKET, SO_RCVTIMEO, &tv,
126			sizeof(timeval)) != 0) {
127		return errno;
128	}
129
130	return B_OK;
131}
132
133
134bigtime_t
135BAbstractSocket::Timeout() const
136{
137	struct timeval tv;
138	socklen_t size = sizeof(tv);
139	if (getsockopt(fSocket, SOL_SOCKET, SO_SNDTIMEO, &tv, &size) != 0)
140		return B_INFINITE_TIMEOUT;
141
142	return tv.tv_sec * 1000000LL + tv.tv_usec;
143}
144
145
146const BNetworkAddress&
147BAbstractSocket::Local() const
148{
149	return fLocal;
150}
151
152
153const BNetworkAddress&
154BAbstractSocket::Peer() const
155{
156	return fPeer;
157}
158
159
160size_t
161BAbstractSocket::MaxTransmissionSize() const
162{
163	return SSIZE_MAX;
164}
165
166
167status_t
168BAbstractSocket::WaitForReadable(bigtime_t timeout) const
169{
170	return _WaitFor(POLLIN, timeout);
171}
172
173
174status_t
175BAbstractSocket::WaitForWritable(bigtime_t timeout) const
176{
177	return _WaitFor(POLLOUT, timeout);
178}
179
180
181int
182BAbstractSocket::Socket() const
183{
184	return fSocket;
185}
186
187
188//	#pragma mark - protected
189
190
191status_t
192BAbstractSocket::Bind(const BNetworkAddress& local, bool reuseAddr, int type)
193{
194	fInitStatus = _OpenIfNeeded(local.Family(), type);
195	if (fInitStatus != B_OK)
196		return fInitStatus;
197
198	if (reuseAddr) {
199		int value = 1;
200		if (setsockopt(Socket(), SOL_SOCKET, SO_REUSEADDR, &value,
201				sizeof(value)) != 0) {
202			return fInitStatus = errno;
203		}
204	}
205
206	if (bind(fSocket, local, local.Length()) != 0)
207		return fInitStatus = errno;
208
209	fIsBound = true;
210	_UpdateLocalAddress();
211	return B_OK;
212}
213
214
215status_t
216BAbstractSocket::Connect(const BNetworkAddress& peer, int type,
217	bigtime_t timeout)
218{
219	Disconnect();
220
221	fInitStatus = _OpenIfNeeded(peer.Family(), type);
222	if (fInitStatus == B_OK)
223		fInitStatus = SetTimeout(timeout);
224
225	if (fInitStatus == B_OK && !IsBound()) {
226		// Bind to ADDR_ANY, if the address family supports it
227		BNetworkAddress local;
228		if (local.SetToWildcard(peer.Family()) == B_OK)
229			fInitStatus = Bind(local, true);
230	}
231	if (fInitStatus != B_OK)
232		return fInitStatus;
233
234	BNetworkAddress normalized = peer;
235	if (connect(fSocket, normalized, normalized.Length()) != 0) {
236		TRACE("%p: connecting to %s: %s\n", this,
237			normalized.ToString().c_str(), strerror(errno));
238		return fInitStatus = errno;
239	}
240
241	fIsConnected = true;
242	fPeer = normalized;
243	_UpdateLocalAddress();
244
245	TRACE("%p: connected to %s (local %s)\n", this, peer.ToString().c_str(),
246		fLocal.ToString().c_str());
247
248	return fInitStatus = B_OK;
249}
250
251
252status_t
253BAbstractSocket::AcceptNext(int& _acceptedSocket, BNetworkAddress& _peer)
254{
255	sockaddr_storage source;
256	socklen_t sourceLength = sizeof(sockaddr_storage);
257
258	int fd = accept(fSocket, (sockaddr*)&source, &sourceLength);
259	if (fd < 0)
260		return fd;
261
262	_peer.SetTo(source);
263	_acceptedSocket = fd;
264	return B_OK;
265}
266
267
268//	#pragma mark - private
269
270
271status_t
272BAbstractSocket::_OpenIfNeeded(int family, int type)
273{
274	if (fSocket >= 0)
275		return B_OK;
276
277	fSocket = socket(family, type, 0);
278	if (fSocket < 0)
279		return errno;
280
281	TRACE("%p: socket opened FD %d\n", this, fSocket);
282	return B_OK;
283}
284
285
286status_t
287BAbstractSocket::_UpdateLocalAddress()
288{
289	socklen_t localLength = sizeof(sockaddr_storage);
290	if (getsockname(fSocket, fLocal, &localLength) != 0)
291		return errno;
292
293	return B_OK;
294}
295
296
297status_t
298BAbstractSocket::_WaitFor(int flags, bigtime_t timeout) const
299{
300	if (fInitStatus != B_OK)
301		return fInitStatus;
302
303	int millis = 0;
304	if (timeout == B_INFINITE_TIMEOUT)
305		millis = -1;
306	if (timeout > 0)
307		millis = timeout / 1000;
308
309	struct pollfd entry;
310	entry.fd = Socket();
311	entry.events = flags;
312
313	int result;
314	do {
315		result = poll(&entry, 1, millis);
316	} while (result == -1 && errno == EINTR);
317	if (result < 0)
318		return errno;
319	if (result == 0)
320		return millis > 0 ? B_TIMED_OUT : B_WOULD_BLOCK;
321
322	return B_OK;
323}
324