1/*
 * io_sock.c - SecureTransport sample I/O module, BSD sockets (Mac OS X) version
3 */
4
#include "ioSock.h"
#include <errno.h>
#include <stdio.h>
#include <string.h>

#include <unistd.h>
#include <sys/types.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <netdb.h>
#include <arpa/inet.h>
#include <fcntl.h>
#include <poll.h>

#include <CoreServices/../Frameworks/CarbonCore.framework/Headers/MacErrors.h>
#include <time.h>
#include <strings.h>
20
/*
 * Compile-time configuration for this module. Each option is a 0/1 switch
 * tested with #if further down in this file.
 */

/* debugging for this module: enables dprintf() output */
#define SSL_OT_DEBUG		1

/* log errors to stdout: enables eprintf() output */
#define SSL_OT_ERRLOG		1

/* trace all low-level network I/O via tprintf() */
#define SSL_OT_IO_TRACE		0

/* if SSL_OT_IO_TRACE, only log non-zero length transfers */
#define SSL_OT_IO_TRACE_NZ	1

/* pause for a CR after each I/O (only meaningful if SSL_OT_IO_TRACE == 1) */
#define SSL_OT_IO_PAUSE		0

/* print a rate-limited stream of dots (outputDot()) while I/O pending */
#define SSL_OT_DOT			1

/*
 * hex-dump the bytes of each I/O (only meaningful if SSL_OT_IO_TRACE == 1);
 * at most SSL_OT_IO_DUMP_SIZE bytes are dumped per transfer
 */
#define SSL_OT_IO_DUMP		0
#define SSL_OT_IO_DUMP_SIZE	1024

/* print a '.' each time SocketRead() returns errSSLWouldBlock */
#define SSL_DISPL_WOULD_BLOCK	0

/*
 * General, not-too-verbose debugging.
 * Usage: dprintf(("fmt", args...)); -- note the double parentheses: the
 * whole printf argument list is passed as the single macro argument.
 * NOTE: this macro shadows POSIX dprintf(3) from <stdio.h> within this file.
 */
#if		SSL_OT_DEBUG
#define dprintf(s)	printf s
#else
#define dprintf(s)
#endif

/* errors --> stdout; same double-parenthesis usage as dprintf() */
#if		SSL_OT_ERRLOG
#define eprintf(s)	printf s
#else
#define eprintf(s)
#endif
59
/*
 * Trace completion of every low-level read/write (SSL_OT_IO_TRACE).
 *
 * str - label for the operation, e.g. "SocketRead"
 * req - number of bytes requested
 * act - number of bytes actually moved
 * buf - the data moved; when SSL_OT_IO_DUMP is enabled, up to
 *       SSL_OT_IO_DUMP_SIZE bytes of it are hex-dumped
 *
 * When SSL_OT_IO_TRACE is 0, tprintf() expands to nothing.
 */
#if		SSL_OT_IO_TRACE
static void tprintf(
	const char *str,
	UInt32 req,
	UInt32 act,
	const UInt8 *buf)
{
	#if	SSL_OT_IO_TRACE_NZ
	/* only log transfers that actually moved data */
	if(act == 0) {
		return;
	}
	#endif
	printf("%s(%u): moved (%u) bytes\n", str, (unsigned)req, (unsigned)act);
	#if	SSL_OT_IO_DUMP
	{
		unsigned i;

		for(i=0; i<act; i++) {
			printf("%02X ", buf[i]);
			if(i >= (SSL_OT_IO_DUMP_SIZE - 1)) {
				break;
			}
			if((i % 32) == 31) {
				putchar('\n');
			}
		}
		printf("\n");
	}
	#else
	(void)buf;
	#endif
	#if SSL_OT_IO_PAUSE
	{
		/* gets() is unbounded (and removed in C11); fgets() bounds the read */
		char instr[20];
		printf("CR to continue: ");
		fflush(stdout);
		if(fgets(instr, sizeof(instr), stdin) == NULL) {
			/* EOF or read error: just continue, but allow later pauses */
			clearerr(stdin);
		}
	}
	#endif
}

#else
#define tprintf(str, req, act, buf)
#endif	/* SSL_OT_IO_TRACE */
102
/*
 * If SSL_OT_DOT, outputDot() prints a '.' at most once every TIME_INTERVAL
 * seconds. This gives the user a visible sign of progress while I/O is
 * pending. When SSL_OT_DOT is 0, outputDot() expands to nothing.
 */

#if	SSL_OT_DOT

/* time at which the last dot was printed; 0 means "never" */
static time_t lastTime = (time_t)0;
#define TIME_INTERVAL		3

static void outputDot(void)
{
	time_t thisTime = time(NULL);

	/* rate limit: stay quiet until TIME_INTERVAL seconds have passed */
	if((thisTime - lastTime) < TIME_INTERVAL) {
		return;
	}
	printf(".");
	fflush(stdout);
	lastTime = thisTime;
}
#else
#define outputDot()
#endif
126
127
/*
 * One-time-only module initialization. The sockets implementation in this
 * file needs no setup, so this is a no-op kept as the module's init hook.
 */
void initSslOt(void)
{
	/* nothing to do */
}
135
136/*
137 * Connect to server.
138 */
139#define GETHOST_RETRIES		3
140
141OSStatus MakeServerConnection(
142	const char *hostName,
143	int port,
144	int nonBlocking,		// 0 or 1
145	otSocket *socketNo, 	// RETURNED
146	PeerSpec *peer)			// RETURNED
147{
148    struct sockaddr_in  addr;
149	struct hostent      *ent;
150    struct in_addr      host;
151	int					sock = 0;
152
153	*socketNo = NULL;
154    if (hostName[0] >= '0' && hostName[0] <= '9')
155    {
156        host.s_addr = inet_addr(hostName);
157    }
158    else {
159		unsigned dex;
160		/* seeing a lot of soft failures here that I really don't want to track down */
161		for(dex=0; dex<GETHOST_RETRIES; dex++) {
162			if(dex != 0) {
163				printf("\n...retrying gethostbyname(%s)", hostName);
164			}
165			ent = gethostbyname(hostName);
166			if(ent != NULL) {
167				break;
168			}
169		}
170        if(ent == NULL) {
171			printf("\n***gethostbyname(%s) returned: %s\n", hostName, hstrerror(h_errno));
172            return ioErr;
173        }
174        memcpy(&host, ent->h_addr, sizeof(struct in_addr));
175    }
176    sock = socket(AF_INET, SOCK_STREAM, 0);
177    addr.sin_addr = host;
178    addr.sin_port = htons((u_short)port);
179
180    addr.sin_family = AF_INET;
181    if (connect(sock, (struct sockaddr *) &addr, sizeof(struct sockaddr_in)) != 0)
182    {   printf("connect returned error\n");
183        return ioErr;
184    }
185
186	if(nonBlocking) {
187		/* OK to do this after connect? */
188		int rtn = fcntl(sock, F_SETFL, O_NONBLOCK);
189		if(rtn == -1) {
190			perror("fctnl(O_NONBLOCK)");
191			return ioErr;
192		}
193	}
194
195    peer->ipAddr = addr.sin_addr.s_addr;
196    peer->port = htons((u_short)port);
197	*socketNo = (otSocket)sock;
198    return noErr;
199}
200
201/*
202 * Set up an otSocket to listen for client connections. Call once, then
203 * use multiple AcceptClientConnection calls.
204 */
205OSStatus ListenForClients(
206	int port,
207	int nonBlocking,		// 0 or 1
208	otSocket *socketNo) 	// RETURNED
209{
210	struct sockaddr_in  addr;
211    struct hostent      *ent;
212    int                 len;
213	int 				sock;
214
215    sock = socket(AF_INET, SOCK_STREAM, 0);
216	if(sock < 1) {
217		perror("socket");
218		return ioErr;
219	}
220
221    ent = gethostbyname("localhost");
222    if (!ent) {
223		perror("gethostbyname");
224		return ioErr;
225    }
226    memcpy(&addr.sin_addr, ent->h_addr, sizeof(struct in_addr));
227
228    addr.sin_port = htons((u_short)port);
229    addr.sin_addr.s_addr = INADDR_ANY;
230    addr.sin_family = AF_INET;
231    len = sizeof(struct sockaddr_in);
232    if (bind(sock, (struct sockaddr *) &addr, len)) {
233		int theErr = errno;
234		perror("bind");
235		if(theErr == EADDRINUSE) {
236			return opWrErr;
237		}
238		else {
239			return ioErr;
240		}
241    }
242	if(nonBlocking) {
243		int rtn = fcntl(sock, F_SETFL, O_NONBLOCK);
244		if(rtn == -1) {
245			perror("fctnl(O_NONBLOCK)");
246			return ioErr;
247		}
248	}
249
250	for(;;) {
251		int rtn = listen(sock, 1);
252		switch(rtn) {
253			case 0:
254				*socketNo = (otSocket)sock;
255				rtn = noErr;
256				break;
257			case EWOULDBLOCK:
258				continue;
259			default:
260				perror("listen");
261				rtn = ioErr;
262				break;
263		}
264		return rtn;
265    }
266	/* NOT REACHED */
267	return 0;
268}
269
270/*
271 * Accept a client connection.
272 */
273
274/*
275 * Currently we always get back a different peer port number on successive
276 * connections, no matter what the client is doing. To test for resumable
277 * session support, force peer port = 0.
278 */
279#define FORCE_ACCEPT_PEER_PORT_ZERO		1
280
281OSStatus AcceptClientConnection(
282	otSocket listenSock, 		// obtained from ListenForClients
283	otSocket *acceptSock, 		// RETURNED
284	PeerSpec *peer)				// RETURNED
285{
286	struct sockaddr_in  addr;
287	int					sock;
288    socklen_t           len;
289
290    len = sizeof(struct sockaddr_in);
291	do {
292		sock = accept((int)listenSock, (struct sockaddr *) &addr, &len);
293		if (sock < 0) {
294			if(errno == EAGAIN) {
295				/* nonblocking, no connection yet */
296				continue;
297			}
298			else {
299				perror("accept");
300				return ioErr;
301			}
302		}
303		else {
304			break;
305		}
306    } while(1);
307	*acceptSock = (otSocket)sock;
308    peer->ipAddr = addr.sin_addr.s_addr;
309	#if	FORCE_ACCEPT_PEER_PORT_ZERO
310	peer->port = 0;
311	#else
312    peer->port = ntohs(addr.sin_port);
313	#endif
314    return noErr;
315}
316
317/*
318 * Shut down a connection.
319 */
320void endpointShutdown(
321	otSocket socket)
322{
323	close((int)socket);
324}
325
326/*
327 * R/W. Called out from SSL.
328 */
329OSStatus SocketRead(
330	SSLConnectionRef 	connection,
331	void 				*data, 			/* owned by
332	 									 * caller, data
333	 									 * RETURNED */
334	size_t 				*dataLength)	/* IN/OUT */
335{
336	UInt32			bytesToGo = *dataLength;
337	UInt32 			initLen = bytesToGo;
338	UInt8			*currData = (UInt8 *)data;
339	int				sock = (int)((long)connection);
340	OSStatus		rtn = noErr;
341	UInt32			bytesRead;
342	ssize_t			rrtn;
343
344	*dataLength = 0;
345
346	for(;;) {
347		bytesRead = 0;
348		/* paranoid check, ensure errno is getting written */
349		errno = -555;
350		rrtn = recv(sock, currData, bytesToGo, 0);
351		if (rrtn <= 0) {
352			if(rrtn == 0) {
353				/* closed, EOF */
354				rtn = errSSLClosedGraceful;
355				break;
356			}
357			int theErr = errno;
358			switch(theErr) {
359				case ENOENT:
360					/*
361					 * Undocumented but I definitely see this.
362					 * Non-blocking sockets only. Definitely retriable
363					 * just like an EAGAIN.
364					 */
365					dprintf(("SocketRead RETRYING on ENOENT, rrtn %d\n",
366						(int)rrtn));
367					/* normal... */
368					//rtn = errSSLWouldBlock;
369					/* ...for temp testing.... */
370					rtn = ioErr;
371					break;
372				case ECONNRESET:
373					/* explicit peer abort */
374					rtn = errSSLClosedAbort;
375					break;
376				case EAGAIN:
377					/* nonblocking, no data */
378					rtn = errSSLWouldBlock;
379					break;
380				default:
381					dprintf(("SocketRead: read(%u) error %d, rrtn %d\n",
382						(unsigned)bytesToGo, theErr, (int)rrtn));
383					rtn = ioErr;
384					break;
385			}
386			/* in any case, we're done with this call if rrtn <= 0 */
387			break;
388		}
389		bytesRead = rrtn;
390		bytesToGo -= bytesRead;
391		currData  += bytesRead;
392
393		if(bytesToGo == 0) {
394			/* filled buffer with incoming data, done */
395			break;
396		}
397	}
398	*dataLength = initLen - bytesToGo;
399	tprintf("SocketRead", initLen, *dataLength, (UInt8 *)data);
400
401	#if SSL_OT_DOT || (SSL_OT_DEBUG && !SSL_OT_IO_TRACE)
402	if((rtn == 0) && (*dataLength == 0)) {
403		/* keep UI alive */
404		outputDot();
405	}
406	#endif
407	#if SSL_DISPL_WOULD_BLOCK
408	if(rtn == errSSLWouldBlock) {
409		printf("."); fflush(stdout);
410	}
411	#endif
412	return rtn;
413}
414
415int oneAtATime = 0;
416
417OSStatus SocketWrite(
418	SSLConnectionRef 	connection,
419	const void	 		*data,
420	size_t 				*dataLength)	/* IN/OUT */
421{
422	size_t		bytesSent = 0;
423	int			sock = (int)((long)connection);
424	int 		length;
425	size_t		dataLen = *dataLength;
426	const UInt8 *dataPtr = (UInt8 *)data;
427	OSStatus	ortn;
428
429	if(oneAtATime && (*dataLength > 1)) {
430		size_t i;
431		size_t outLen;
432		size_t thisMove;
433
434		outLen = 0;
435		for(i=0; i<dataLen; i++) {
436			thisMove = 1;
437			ortn = SocketWrite(connection, dataPtr, &thisMove);
438			outLen += thisMove;
439			dataPtr++;
440			if(ortn) {
441				return ortn;
442			}
443		}
444		return noErr;
445	}
446	*dataLength = 0;
447
448    do {
449        length = write(sock,
450				(char*)dataPtr + bytesSent,
451				dataLen - bytesSent);
452    } while ((length > 0) &&
453			 ( (bytesSent += length) < dataLen) );
454
455	if(length <= 0) {
456		int theErr = errno;
457		switch(theErr) {
458			case EAGAIN:
459				ortn = errSSLWouldBlock; break;
460			case EPIPE:
461			/* as of Leopard 9A312 or so, the error formerly seen as EPIPE is
462			 * now reported as ECONNRESET. This happens when we're catching
463			 * SIGPIPE and we write to a socket which has been closed by the peer.
464			 */
465			case ECONNRESET:
466				ortn = errSSLClosedAbort; break;
467			default:
468				dprintf(("SocketWrite: write(%u) error %d\n",
469					  (unsigned)(dataLen - bytesSent), theErr));
470				ortn = ioErr;
471				break;
472		}
473	}
474	else {
475		ortn = noErr;
476	}
477	tprintf("SocketWrite", dataLen, bytesSent, dataPtr);
478	*dataLength = bytesSent;
479	return ortn;
480}
481