1/*
2 * Copyright (c) 2006-2008,2010,2012-2013 Apple Inc. All Rights Reserved.
3 *
4 * io_sock.c - SecureTransport sample I/O module, X sockets version
5 */
6
#include "ioSock.h"
#include <errno.h>
#include <stdio.h>
#include <string.h>
#include <Security/SecBase.h>
#include <unistd.h>
#include <sys/types.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <netdb.h>
#include <arpa/inet.h>
#include <fcntl.h>
#include <poll.h>

#include <Security/SecBase.h>

#include <time.h>
#include <strings.h>
23
/*
 * Compile-time switches for this module; each is 0 (off) or 1 (on).
 */

/* general debug output through dprintf() */
#define SSL_OT_DEBUG			1

/* error output through eprintf(), written to stdout */
#define SSL_OT_ERRLOG			1

/* log every low-level network read/write */
#define SSL_OT_IO_TRACE			0

/* with SSL_OT_IO_TRACE: skip logging of zero-byte transfers */
#define SSL_OT_IO_TRACE_NZ		1

/* with SSL_OT_IO_TRACE: wait for the user after each transfer */
#define SSL_OT_IO_PAUSE			0

/* emit a stream of dots while I/O is pending */
#define SSL_OT_DOT				1

/* with SSL_OT_IO_TRACE: hex-dump up to SSL_OT_IO_DUMP_SIZE bytes per transfer */
#define SSL_OT_IO_DUMP			0
#define SSL_OT_IO_DUMP_SIZE		256

/* print a '.' whenever a read reports errSSLWouldBlock */
#define SSL_DISPL_WOULD_BLOCK	0

/*
 * Logging macros. The argument is a complete, parenthesized printf
 * argument list, e.g. dprintf(("n = %d\n", n)); when the matching switch
 * is off the call expands to nothing.
 *
 * NOTE: dprintf shadows the POSIX dprintf(int fd, const char *fmt, ...)
 * that <stdio.h> may declare; renaming it would mean touching every caller.
 */
#if		SSL_OT_DEBUG
#define dprintf(args)	printf args
#else
#define dprintf(args)
#endif

#if		SSL_OT_ERRLOG
#define eprintf(args)	printf args
#else
#define eprintf(args)
#endif
62
/*
 * Trace completion of every socket read/write (only when SSL_OT_IO_TRACE).
 *
 * str - label printed for the transfer (e.g. "SocketRead")
 * req - number of bytes requested
 * act - number of bytes actually moved
 * buf - start of the transferred data; hex-dumped (at most
 *       SSL_OT_IO_DUMP_SIZE bytes) when SSL_OT_IO_DUMP is on
 *
 * With tracing off, tprintf() expands to nothing.
 */
#if		SSL_OT_IO_TRACE
static void tprintf(
	const char *str,
	UInt32 req,
	UInt32 act,
	const UInt8 *buf)
{
	#if	SSL_OT_IO_TRACE_NZ
	if(act == 0) {
		return;
	}
	#endif
	printf("%s(%u): moved (%u) bytes\n", str, (unsigned)req, (unsigned)act);
	#if	SSL_OT_IO_DUMP
	{
		unsigned i;

		for(i=0; i<act; i++) {
			printf("%02X ", buf[i]);
			if(i >= (SSL_OT_IO_DUMP_SIZE - 1)) {
				break;
			}
		}
		printf("\n");
	}
	#else
	(void)buf;	/* only needed for the hex dump */
	#endif
	#if SSL_OT_IO_PAUSE
	{
		/*
		 * gets() cannot bound its write and was removed in C11;
		 * fgets() limits the read to the buffer size.
		 */
		char instr[20];
		printf("CR to continue: ");
		fflush(stdout);		/* prompt has no newline; make sure it shows */
		if(fgets(instr, sizeof(instr), stdin) == NULL) {
			/* EOF or error on stdin: nothing to wait for, just continue */
		}
	}
	#endif
}

#else
#define tprintf(str, req, act, buf)
#endif	/* SSL_OT_IO_TRACE */
102
/*
 * If SSL_OT_DOT, print a '.' at most once every TIME_INTERVAL seconds
 * while waiting on I/O, so the user can see the program is still alive.
 * Otherwise outputDot() expands to nothing.
 */

#if	SSL_OT_DOT

/* when the last dot was printed; 0 until the first one */
static time_t lastTime = (time_t)0;
#define TIME_INTERVAL		3

static void outputDot(void)
{
	time_t thisTime = time(NULL);

	/* difftime() is the portable way to compare time_t values */
	if(difftime(thisTime, lastTime) >= TIME_INTERVAL) {
		printf("."); fflush(stdout);
		lastTime = thisTime;
	}
}
#else
#define outputDot()
#endif
126
127
/*
 * One-time module initialization. The sockets implementation has no
 * global state to prepare, so this currently does nothing.
 */
void initSslOt(void)
{
	/* intentionally empty */
}
135
136/*
137 * Connect to server.
138 */
139#define GETHOST_RETRIES		3
140
141OSStatus MakeServerConnection(
142	const char *hostName,
143	int port,
144	int nonBlocking,		// 0 or 1
145	otSocket *socketNo, 	// RETURNED
146	PeerSpec *peer)			// RETURNED
147{
148    struct sockaddr_in  addr;
149	struct hostent      *ent;
150    struct in_addr      host;
151	int					sock = 0;
152
153	*socketNo = 0;
154    if (hostName[0] >= '0' && hostName[0] <= '9')
155    {
156        host.s_addr = inet_addr(hostName);
157    }
158    else {
159		unsigned dex;
160		/* seeing a lot of soft failures here that I really don't want to track down */
161		for(dex=0; dex<GETHOST_RETRIES; dex++) {
162			if(dex != 0) {
163				printf("\n...retrying gethostbyname(%s)", hostName);
164			}
165			ent = gethostbyname(hostName);
166			if(ent != NULL) {
167				break;
168			}
169		}
170        if(ent == NULL) {
171			printf("\n***gethostbyname(%s) returned: %s\n", hostName, hstrerror(h_errno));
172            return errSecIO;
173        }
174        memcpy(&host, ent->h_addr, sizeof(struct in_addr));
175    }
176    sock = socket(AF_INET, SOCK_STREAM, 0);
177    addr.sin_addr = host;
178    addr.sin_port = htons((u_short)port);
179
180    addr.sin_family = AF_INET;
181    if (connect(sock, (struct sockaddr *) &addr, sizeof(struct sockaddr_in)) != 0)
182    {   printf("connect returned error\n");
183        return errSecIO;
184    }
185
186	if(nonBlocking) {
187		/* OK to do this after connect? */
188		int rtn = fcntl(sock, F_SETFL, O_NONBLOCK);
189		if(rtn == -1) {
190			perror("fctnl(O_NONBLOCK)");
191			return errSecIO;
192		}
193	}
194
195    peer->ipAddr = addr.sin_addr.s_addr;
196    peer->port = htons((u_short)port);
197	*socketNo = (otSocket)sock;
198    return errSecSuccess;
199}
200
201/*
202 * Set up an otSocket to listen for client connections. Call once, then
203 * use multiple AcceptClientConnection calls.
204 */
205OSStatus ListenForClients(
206	int port,
207	int nonBlocking,		// 0 or 1
208	otSocket *socketNo) 	// RETURNED
209{
210	struct sockaddr_in  addr;
211    struct hostent      *ent;
212    int                 len;
213	int 				sock;
214
215    sock = socket(AF_INET, SOCK_STREAM, 0);
216	if(sock < 1) {
217		perror("socket");
218		return errSecIO;
219	}
220
221    int reuse = 1;
222    int err = setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse));
223    if (err != 0) {
224        perror("setsockopt");
225        return err;
226    }
227
228    ent = gethostbyname("localhost");
229    if (!ent) {
230		perror("gethostbyname");
231		return errSecIO;
232    }
233    memcpy(&addr.sin_addr, ent->h_addr, sizeof(struct in_addr));
234
235    addr.sin_port = htons((u_short)port);
236    addr.sin_addr.s_addr = INADDR_ANY;
237    addr.sin_family = AF_INET;
238    len = sizeof(struct sockaddr_in);
239    if (bind(sock, (struct sockaddr *) &addr, len)) {
240		int theErr = errno;
241		perror("bind");
242		if(theErr == EADDRINUSE) {
243			return errSecOpWr;
244		}
245		else {
246			return errSecIO;
247		}
248    }
249	if(nonBlocking) {
250		int rtn = fcntl(sock, F_SETFL, O_NONBLOCK);
251		if(rtn == -1) {
252			perror("fctnl(O_NONBLOCK)");
253			return errSecIO;
254		}
255	}
256
257	for(;;) {
258		int rtn = listen(sock, 1);
259		switch(rtn) {
260			case 0:
261				*socketNo = (otSocket)sock;
262				rtn = errSecSuccess;
263				break;
264			case EWOULDBLOCK:
265				continue;
266			default:
267				perror("listen");
268				rtn = errSecIO;
269				break;
270		}
271		return rtn;
272    }
273	/* NOT REACHED */
274	return 0;
275}
276
277/*
278 * Accept a client connection.
279 */
280
281/*
282 * Currently we always get back a different peer port number on successive
283 * connections, no matter what the client is doing. To test for resumable
284 * session support, force peer port = 0.
285 */
286#define FORCE_ACCEPT_PEER_PORT_ZERO		1
287
288OSStatus AcceptClientConnection(
289	otSocket listenSock, 		// obtained from ListenForClients
290	otSocket *acceptSock, 		// RETURNED
291	PeerSpec *peer)				// RETURNED
292{
293	struct sockaddr_in  addr;
294	int					sock;
295    socklen_t           len;
296
297    len = sizeof(struct sockaddr_in);
298	do {
299		sock = accept((int)listenSock, (struct sockaddr *) &addr, &len);
300		if (sock < 0) {
301			if(errno == EAGAIN) {
302				/* nonblocking, no connection yet */
303				continue;
304			}
305			else {
306				perror("accept");
307				return errSecIO;
308			}
309		}
310		else {
311			break;
312		}
313    } while(1);
314	*acceptSock = (otSocket)sock;
315    peer->ipAddr = addr.sin_addr.s_addr;
316	#if	FORCE_ACCEPT_PEER_PORT_ZERO
317	peer->port = 0;
318	#else
319    peer->port = ntohs(addr.sin_port);
320	#endif
321    return errSecSuccess;
322}
323
324/*
325 * Shut down a connection.
326 */
327void endpointShutdown(
328	otSocket sock)
329{
330	close((int)sock);
331}
332
333/*
334 * R/W. Called out from SSL.
335 */
336OSStatus SocketRead(
337	SSLConnectionRef 	connection,
338	void 				*data, 			/* owned by
339	 									 * caller, data
340	 									 * RETURNED */
341	size_t 				*dataLength)	/* IN/OUT */
342{
343	size_t			bytesToGo = *dataLength;
344	size_t 			initLen = bytesToGo;
345	UInt8			*currData = (UInt8 *)data;
346	int				sock = (int)((long)connection);
347	OSStatus		rtn = errSecSuccess;
348	size_t			bytesRead = 0;
349	ssize_t			rrtn;
350
351	*dataLength = 0;
352
353	for(;;) {
354		/* paranoid check, ensure errno is getting written */
355		errno = -555;
356		rrtn = recv(sock, currData, bytesToGo, 0);
357		if (rrtn <= 0) {
358			if(rrtn == 0) {
359				/* closed, EOF */
360				rtn = errSSLClosedGraceful;
361				break;
362			}
363			int theErr = errno;
364			switch(theErr) {
365				case ENOENT:
366					/*
367					 * Undocumented but I definitely see this.
368					 * Non-blocking sockets only. Definitely retriable
369					 * just like an EAGAIN.
370					 */
371					dprintf(("SocketRead RETRYING on ENOENT, rrtn %d\n",
372						(int)rrtn));
373					/* normal... */
374					//rtn = errSSLWouldBlock;
375					/* ...for temp testing.... */
376					rtn = errSecIO;
377					break;
378				case ECONNRESET:
379					/* explicit peer abort */
380					rtn = errSSLClosedAbort;
381					break;
382				case EAGAIN:
383					/* nonblocking, no data */
384					rtn = errSSLWouldBlock;
385					break;
386				default:
387					dprintf(("SocketRead: read(%u) error %d, rrtn %d\n",
388						(unsigned)bytesToGo, theErr, (int)rrtn));
389					rtn = errSecIO;
390					break;
391			}
392			/* in any case, we're done with this call if rrtn <= 0 */
393			break;
394		}
395		bytesRead = rrtn;
396		bytesToGo -= bytesRead;
397		currData  += bytesRead;
398
399		if(bytesToGo == 0) {
400			/* filled buffer with incoming data, done */
401			break;
402		}
403	}
404	*dataLength = initLen - bytesToGo;
405	tprintf("SocketRead", initLen, *dataLength, (UInt8 *)data);
406
407	#if SSL_OT_DOT || (SSL_OT_DEBUG && !SSL_OT_IO_TRACE)
408	if((rtn == 0) && (*dataLength == 0)) {
409		/* keep UI alive */
410		outputDot();
411	}
412	#endif
413	#if SSL_DISPL_WOULD_BLOCK
414	if(rtn == errSSLWouldBlock) {
415		printf("."); fflush(stdout);
416	}
417	#endif
418	return rtn;
419}
420
421int oneAtATime = 0;
422
423OSStatus SocketWrite(
424	SSLConnectionRef 	connection,
425	const void	 		*data,
426	size_t 				*dataLength)	/* IN/OUT */
427{
428	size_t		bytesSent = 0;
429	int			sock = (int)((long)connection);
430	size_t 		length;
431	size_t		dataLen = *dataLength;
432	const UInt8 *dataPtr = (UInt8 *)data;
433	OSStatus	ortn;
434
435	if(oneAtATime && (*dataLength > 1)) {
436		size_t i;
437		size_t outLen;
438		size_t thisMove;
439
440		outLen = 0;
441		for(i=0; i<dataLen; i++) {
442			thisMove = 1;
443			ortn = SocketWrite(connection, dataPtr, &thisMove);
444			outLen += thisMove;
445			dataPtr++;
446			if(ortn) {
447				return ortn;
448			}
449		}
450		return errSecSuccess;
451	}
452	*dataLength = 0;
453
454    do {
455        length = write(sock,
456				(char*)dataPtr + bytesSent,
457				dataLen - bytesSent);
458    } while ((length > 0) &&
459			 ( (bytesSent += length) < dataLen) );
460
461	if(length <= 0) {
462		int theErr = errno;
463		switch(theErr) {
464			case EAGAIN:
465				ortn = errSSLWouldBlock; break;
466			case EPIPE:
467				ortn = errSSLClosedAbort; break;
468			default:
469				dprintf(("SocketWrite: write(%u) error %d\n",
470					  (unsigned)(dataLen - bytesSent), theErr));
471				ortn = errSecIO;
472				break;
473		}
474	}
475	else {
476		ortn = errSecSuccess;
477	}
478	tprintf("SocketWrite", dataLen, bytesSent, dataPtr);
479	*dataLength = bytesSent;
480	return ortn;
481}
482