1/*
2 * Copyright 2011, Axel Dörfler, axeld@pinc-software.de.
3 * Distributed under the terms of the MIT License.
4 */
5
6
#include <AbstractSocket.h>

#include <arpa/inet.h>
#include <errno.h>
#include <fcntl.h>
#include <limits.h>
#include <netinet/in.h>
#include <sys/poll.h>
13
14
// Debug tracing: define TRACE_SOCKET to print socket lifecycle events.
// The traced code paths use printf() and strerror(), so their headers are
// pulled in here; otherwise enabling tracing would fail to compile.
//#define TRACE_SOCKET
#ifdef TRACE_SOCKET
#	include <stdio.h>
#	include <string.h>
#	define TRACE(x...) printf(x)
#else
#	define TRACE(x...) ;
#endif
21
22
23BAbstractSocket::BAbstractSocket()
24	:
25	fInitStatus(B_NO_INIT),
26	fSocket(-1),
27	fIsBound(false),
28	fIsConnected(false)
29{
30}
31
32
33BAbstractSocket::BAbstractSocket(const BAbstractSocket& other)
34	:
35	fInitStatus(other.fInitStatus),
36	fLocal(other.fLocal),
37	fPeer(other.fPeer),
38	fIsConnected(other.fIsConnected)
39{
40	fSocket = dup(other.fSocket);
41	if (fSocket < 0)
42		fInitStatus = errno;
43}
44
45
46BAbstractSocket::~BAbstractSocket()
47{
48	Disconnect();
49}
50
51
52status_t
53BAbstractSocket::InitCheck() const
54{
55	return fInitStatus;
56}
57
58
59bool
60BAbstractSocket::IsBound() const
61{
62	return fIsBound;
63}
64
65
66bool
67BAbstractSocket::IsConnected() const
68{
69	return fIsConnected;
70}
71
72
73void
74BAbstractSocket::Disconnect()
75{
76	if (fSocket < 0)
77		return;
78
79	TRACE("%p: BAbstractSocket::Disconnect()\n", this);
80
81	close(fSocket);
82	fSocket = -1;
83	fIsConnected = false;
84	fIsBound = false;
85}
86
87
88status_t
89BAbstractSocket::SetTimeout(bigtime_t timeout)
90{
91	if (timeout < 0)
92		timeout = 0;
93
94	struct timeval tv;
95	tv.tv_sec = timeout / 1000000LL;
96	tv.tv_usec = timeout % 1000000LL;
97
98	if (setsockopt(fSocket, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(timeval)) != 0
99		|| setsockopt(fSocket, SOL_SOCKET, SO_RCVTIMEO, &tv,
100				sizeof(timeval)) != 0)
101		return errno;
102
103	return B_OK;
104}
105
106
107bigtime_t
108BAbstractSocket::Timeout() const
109{
110	struct timeval tv;
111	socklen_t size = sizeof(tv);
112	if (getsockopt(fSocket, SOL_SOCKET, SO_SNDTIMEO, &tv, &size) != 0)
113		return B_INFINITE_TIMEOUT;
114
115	return tv.tv_sec * 1000000LL + tv.tv_usec;
116}
117
118
119const BNetworkAddress&
120BAbstractSocket::Local() const
121{
122	return fLocal;
123}
124
125
126const BNetworkAddress&
127BAbstractSocket::Peer() const
128{
129	return fPeer;
130}
131
132
133size_t
134BAbstractSocket::MaxTransmissionSize() const
135{
136	return SSIZE_MAX;
137}
138
139
140status_t
141BAbstractSocket::WaitForReadable(bigtime_t timeout) const
142{
143	return _WaitFor(POLLIN, timeout);
144}
145
146
147status_t
148BAbstractSocket::WaitForWritable(bigtime_t timeout) const
149{
150	return _WaitFor(POLLOUT, timeout);
151}
152
153
154int
155BAbstractSocket::Socket() const
156{
157	return fSocket;
158}
159
160
161//	#pragma mark - protected
162
163
164status_t
165BAbstractSocket::Bind(const BNetworkAddress& local, int type)
166{
167	fInitStatus = _OpenIfNeeded(local.Family(), type);
168	if (fInitStatus != B_OK)
169		return fInitStatus;
170
171	if (bind(fSocket, local, local.Length()) != 0)
172		return fInitStatus = errno;
173
174	fIsBound = true;
175	_UpdateLocalAddress();
176	return B_OK;
177}
178
179
180status_t
181BAbstractSocket::Connect(const BNetworkAddress& peer, int type,
182	bigtime_t timeout)
183{
184	Disconnect();
185
186	fInitStatus = _OpenIfNeeded(peer.Family(), type);
187	if (fInitStatus == B_OK)
188		fInitStatus = SetTimeout(timeout);
189
190	if (fInitStatus == B_OK && !IsBound()) {
191		BNetworkAddress local;
192		local.SetToWildcard(peer.Family());
193		fInitStatus = Bind(local);
194	}
195	if (fInitStatus != B_OK)
196		return fInitStatus;
197
198	BNetworkAddress normalized = peer;
199	if (connect(fSocket, normalized, normalized.Length()) != 0) {
200		TRACE("%p: connecting to %s: %s\n", this,
201			normalized.ToString().c_str(), strerror(errno));
202		return fInitStatus = errno;
203	}
204
205	fIsConnected = true;
206	fPeer = normalized;
207	_UpdateLocalAddress();
208
209	TRACE("%p: connected to %s (local %s)\n", this, peer.ToString().c_str(),
210		fLocal.ToString().c_str());
211
212	return fInitStatus = B_OK;
213}
214
215
216//	#pragma mark - private
217
218
219status_t
220BAbstractSocket::_OpenIfNeeded(int family, int type)
221{
222	if (fSocket >= 0)
223		return B_OK;
224
225	fSocket = socket(family, type, 0);
226	if (fSocket < 0)
227		return errno;
228
229	TRACE("%p: socket opened FD %d\n", this, fSocket);
230	return B_OK;
231}
232
233
234status_t
235BAbstractSocket::_UpdateLocalAddress()
236{
237	socklen_t localLength = sizeof(sockaddr_storage);
238	if (getsockname(fSocket, fLocal, &localLength) != 0)
239		return errno;
240
241	return B_OK;
242}
243
244
245status_t
246BAbstractSocket::_WaitFor(int flags, bigtime_t timeout) const
247{
248	if (fInitStatus != B_OK)
249		return fInitStatus;
250
251	int millis = 0;
252	if (timeout == B_INFINITE_TIMEOUT)
253		millis = -1;
254	if (timeout > 0)
255		millis = timeout / 1000;
256
257	struct pollfd entry;
258	entry.fd = Socket();
259	entry.events = flags;
260
261	int result = poll(&entry, 1, millis);
262	if (result < 0)
263		return errno;
264	if (result == 0)
265		return millis > 0 ? B_TIMED_OUT : B_WOULD_BLOCK;
266
267	return B_OK;
268}
269