1/*
2 * Copyright (c) 2006-2008,2010-2012 Apple Inc. All Rights Reserved.
3 *
4 * @APPLE_LICENSE_HEADER_START@
5 *
6 * This file contains Original Code and/or Modifications of Original Code
7 * as defined in and that are subject to the Apple Public Source License
8 * Version 2.0 (the 'License'). You may not use this file except in
9 * compliance with the License. Please obtain a copy of the License at
10 * http://www.opensource.apple.com/apsl/ and read it before using this
11 * file.
12 *
13 * The Original Code and all software distributed under the License are
14 * distributed on an 'AS IS' basis, WITHOUT WARRANTY OF ANY KIND, EITHER
15 * EXPRESS OR IMPLIED, AND APPLE HEREBY DISCLAIMS ALL SUCH WARRANTIES,
16 * INCLUDING WITHOUT LIMITATION, ANY WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE, QUIET ENJOYMENT OR NON-INFRINGEMENT.
18 * Please see the License for the specific language governing rights and
19 * limitations under the License.
20 *
21 * @APPLE_LICENSE_HEADER_END@
22 */
23
24/*
25 * ioSock.c - socket-based I/O routines for use with Secure Transport
26 */
27
#include "ioSock.h"
#include <errno.h>
#include <stdio.h>
#include <string.h>

#include <unistd.h>
#include <sys/types.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <netdb.h>
#include <arpa/inet.h>
#include <fcntl.h>

#include <Security/SecBase.h>
#include <time.h>
#include <strings.h>
43
/*
 * Compile-time knobs for this module; 1 enables, 0 disables.
 */

/* debugging for this module */
#define SSL_OT_DEBUG		1
/* log errors to stdout */
#define SSL_OT_ERRLOG		1
/* trace all low-level network I/O */
#define SSL_OT_IO_TRACE		0
/* if SSL_OT_IO_TRACE, only log non-zero length transfers */
#define SSL_OT_IO_TRACE_NZ	1
/* pause after each I/O (only meaningful if SSL_OT_IO_TRACE == 1) */
#define SSL_OT_IO_PAUSE		0
/* print a stream of dots while I/O pending */
#define SSL_OT_DOT			1
/* dump some bytes of each I/O (only meaningful if SSL_OT_IO_TRACE == 1) */
#define SSL_OT_IO_DUMP		0
#define SSL_OT_IO_DUMP_SIZE	256
/* indicate errSSLWouldBlock with a '.' */
#define SSL_DISPL_WOULD_BLOCK	0

/*
 * dprintf: general, not-too-verbose debugging.
 * eprintf: errors to stdout.
 * Each takes a fully parenthesized printf argument list, e.g.
 * dprintf(("x=%d\n", x)), and expands to nothing when its knob is 0.
 */
#if !SSL_OT_DEBUG
#define dprintf(s)
#else
#define dprintf(s)	printf s
#endif

#if !SSL_OT_ERRLOG
#define eprintf(s)
#else
#define eprintf(s)	printf s
#endif
82
/*
 * tprintf: trace completion of every r/w.
 *
 * When SSL_OT_IO_TRACE is enabled this logs the requested (req) and actual
 * (act) byte counts of the I/O call named by str, optionally hex-dumps up to
 * SSL_OT_IO_DUMP_SIZE bytes of buf, and optionally waits for the user to
 * press return. When tracing is disabled it expands to nothing.
 */
#if		SSL_OT_IO_TRACE
static void tprintf(
	const char *str,
	UInt32 req,
	UInt32 act,
	const UInt8 *buf)
{
	#if	SSL_OT_IO_TRACE_NZ
	/* only log non-zero length transfers */
	if(act == 0) {
		return;
	}
	#endif
	printf("%s(%u): moved (%u) bytes\n", str, (unsigned)req, (unsigned)act);
	#if	SSL_OT_IO_DUMP
	{
		unsigned i;

		/* hex-dump at most SSL_OT_IO_DUMP_SIZE of the transferred bytes */
		for(i=0; (i < act) && (i < SSL_OT_IO_DUMP_SIZE); i++) {
			printf("%02X ", buf[i]);
		}
		printf("\n");
	}
	#else
	(void)buf;
	#endif
	#if SSL_OT_IO_PAUSE
	{
		int c;

		printf("CR to continue: ");
		fflush(stdout);
		/*
		 * Consume one input line. This replaces gets(), which cannot bound
		 * its write into the destination buffer and was removed in C11.
		 */
		while(((c = getchar()) != '\n') && (c != EOF)) {
			/* discard */
		}
	}
	#endif
}

#else
#define tprintf(str, req, act, buf)
#endif	/* SSL_OT_IO_TRACE */
122
/*
 * If SSL_OT_DOT, output a '.' at most once every TIME_INTERVAL seconds,
 * to keep the UI visibly alive while waiting on the network. Without
 * SSL_OT_DOT, outputDot() expands to nothing.
 */

#if	SSL_OT_DOT

/* wall-clock time of the most recent dot; 0 means "never printed" */
static time_t lastTime = (time_t)0;
#define TIME_INTERVAL		3

static void outputDot(void)
{
	time_t thisTime = time(NULL);

	if((thisTime - lastTime) >= TIME_INTERVAL) {
		printf("."); fflush(stdout);
		lastTime = thisTime;
	}
}
#else
#define outputDot()
#endif
146
147
/*
 * One-time module initialization hook. The socket-based implementation
 * has no state to set up, so this is currently a no-op.
 */
void initSslOt(void)
{
	/* nothing to initialize */
	return;
}
155
156/*
157 * Connect to server.
158 */
159#define GETHOST_RETRIES		3
160
161OSStatus MakeServerConnection(
162	const char *hostName,
163	int port,
164	int nonBlocking,		// 0 or 1
165	otSocket *socketNo, 	// RETURNED
166	PeerSpec *peer)			// RETURNED
167{
168    struct sockaddr_in  addr;
169	struct hostent      *ent;
170    struct in_addr      host;
171	int					sock = 0;
172
173	*socketNo = 0;
174    if (hostName[0] >= '0' && hostName[0] <= '9')
175    {
176        host.s_addr = inet_addr(hostName);
177    }
178    else {
179		unsigned dex;
180		/* seeing a lot of soft failures here that I really don't want to track down */
181		for(dex=0; dex<GETHOST_RETRIES; dex++) {
182			if(dex != 0) {
183				printf("\n...retrying gethostbyname(%s)", hostName);
184			}
185			ent = gethostbyname(hostName);
186			if(ent != NULL) {
187				break;
188			}
189		}
190        if(ent == NULL) {
191			printf("\n***gethostbyname(%s) returned: %s\n", hostName, hstrerror(h_errno));
192            return errSecIO;
193        }
194        memcpy(&host, ent->h_addr, sizeof(struct in_addr));
195    }
196    sock = socket(AF_INET, SOCK_STREAM, 0);
197    addr.sin_addr = host;
198    addr.sin_port = htons((u_short)port);
199
200    addr.sin_family = AF_INET;
201    if (connect(sock, (struct sockaddr *) &addr, sizeof(struct sockaddr_in)) != 0)
202    {   printf("connect returned error\n");
203        return errSecIO;
204    }
205
206	if(nonBlocking) {
207		/* OK to do this after connect? */
208		int rtn = fcntl(sock, F_SETFL, O_NONBLOCK);
209		if(rtn == -1) {
210			perror("fctnl(O_NONBLOCK)");
211			return errSecIO;
212		}
213	}
214
215    peer->ipAddr = addr.sin_addr.s_addr;
216    peer->port = htons((u_short)port);
217	*socketNo = (otSocket)sock;
218    return errSecSuccess;
219}
220
221/*
222 * Set up an otSocket to listen for client connections. Call once, then
223 * use multiple AcceptClientConnection calls.
224 */
225OSStatus ListenForClients(
226	int port,
227	int nonBlocking,		// 0 or 1
228	otSocket *socketNo) 	// RETURNED
229{
230	struct sockaddr_in  addr;
231    struct hostent      *ent;
232    int                 len;
233	int 				sock;
234
235    sock = socket(AF_INET, SOCK_STREAM, 0);
236	if(sock < 1) {
237		perror("socket");
238		return errSecIO;
239	}
240
241    int reuse = 1;
242    int err = setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse));
243    if (err != 0) {
244        perror("setsockopt");
245        return err;
246    }
247
248    ent = gethostbyname("localhost");
249    if (!ent) {
250		perror("gethostbyname");
251		return errSecIO;
252    }
253    memcpy(&addr.sin_addr, ent->h_addr, sizeof(struct in_addr));
254
255    addr.sin_port = htons((u_short)port);
256    addr.sin_addr.s_addr = INADDR_ANY;
257    addr.sin_family = AF_INET;
258    len = sizeof(struct sockaddr_in);
259    if (bind(sock, (struct sockaddr *) &addr, len)) {
260		int theErr = errno;
261		perror("bind");
262		if(theErr == EADDRINUSE) {
263			return errSecOpWr;
264		}
265		else {
266			return errSecIO;
267		}
268    }
269	if(nonBlocking) {
270		int rtn = fcntl(sock, F_SETFL, O_NONBLOCK);
271		if(rtn == -1) {
272			perror("fctnl(O_NONBLOCK)");
273			return errSecIO;
274		}
275	}
276
277	for(;;) {
278		int rtn = listen(sock, 1);
279		switch(rtn) {
280			case 0:
281				*socketNo = (otSocket)sock;
282				rtn = errSecSuccess;
283				break;
284			case EWOULDBLOCK:
285				continue;
286			default:
287				perror("listen");
288				rtn = errSecIO;
289				break;
290		}
291		return rtn;
292    }
293	/* NOT REACHED */
294	return 0;
295}
296
297/*
298 * Accept a client connection.
299 */
300
301/*
302 * Currently we always get back a different peer port number on successive
303 * connections, no matter what the client is doing. To test for resumable
304 * session support, force peer port = 0.
305 */
306#define FORCE_ACCEPT_PEER_PORT_ZERO		1
307
308OSStatus AcceptClientConnection(
309	otSocket listenSock, 		// obtained from ListenForClients
310	otSocket *acceptSock, 		// RETURNED
311	PeerSpec *peer)				// RETURNED
312{
313	struct sockaddr_in  addr;
314	int					sock;
315    socklen_t           len;
316
317    len = sizeof(struct sockaddr_in);
318	do {
319		sock = accept((int)listenSock, (struct sockaddr *) &addr, &len);
320		if (sock < 0) {
321			if(errno == EAGAIN) {
322				/* nonblocking, no connection yet */
323				continue;
324			}
325			else {
326				perror("accept");
327				return errSecIO;
328			}
329		}
330		else {
331			break;
332		}
333    } while(1);
334	*acceptSock = (otSocket)sock;
335    peer->ipAddr = addr.sin_addr.s_addr;
336	#if	FORCE_ACCEPT_PEER_PORT_ZERO
337	peer->port = 0;
338	#else
339    peer->port = ntohs(addr.sin_port);
340	#endif
341    return errSecSuccess;
342}
343
344/*
345 * Shut down a connection.
346 */
347void endpointShutdown(
348	otSocket sock)
349{
350	close((int)sock);
351}
352
353/*
354 * R/W. Called out from SSL.
355 */
356OSStatus SocketRead(
357	SSLConnectionRef 	connection,
358	void 				*data, 			/* owned by
359	 									 * caller, data
360	 									 * RETURNED */
361	size_t 				*dataLength)	/* IN/OUT */
362{
363	UInt32			bytesToGo = *dataLength;
364	UInt32 			initLen = bytesToGo;
365	UInt8			*currData = (UInt8 *)data;
366	int				sock = (int)((long)connection);
367	OSStatus		rtn = errSecSuccess;
368	UInt32			bytesRead;
369	ssize_t			rrtn;
370
371	*dataLength = 0;
372
373	for(;;) {
374		bytesRead = 0;
375		/* paranoid check, ensure errno is getting written */
376		errno = -555;
377		rrtn = recv(sock, currData, bytesToGo, 0);
378		if (rrtn <= 0) {
379			if(rrtn == 0) {
380				/* closed, EOF */
381				rtn = errSSLClosedGraceful;
382				break;
383			}
384			int theErr = errno;
385			switch(theErr) {
386				case ENOENT:
387					/*
388					 * Undocumented but I definitely see this.
389					 * Non-blocking sockets only. Definitely retriable
390					 * just like an EAGAIN.
391					 */
392					dprintf(("SocketRead RETRYING on ENOENT, rrtn %d\n",
393						(int)rrtn));
394					/* normal... */
395					//rtn = errSSLWouldBlock;
396					/* ...for temp testing.... */
397					rtn = errSecIO;
398					break;
399				case ECONNRESET:
400					/* explicit peer abort */
401					rtn = errSSLClosedAbort;
402					break;
403				case EAGAIN:
404					/* nonblocking, no data */
405					rtn = errSSLWouldBlock;
406					break;
407				default:
408					dprintf(("SocketRead: read(%u) error %d, rrtn %d\n",
409						(unsigned)bytesToGo, theErr, (int)rrtn));
410					rtn = errSecIO;
411					break;
412			}
413			/* in any case, we're done with this call if rrtn <= 0 */
414			break;
415		}
416		bytesRead = rrtn;
417		bytesToGo -= bytesRead;
418		currData  += bytesRead;
419
420		if(bytesToGo == 0) {
421			/* filled buffer with incoming data, done */
422			break;
423		}
424	}
425	*dataLength = initLen - bytesToGo;
426	tprintf("SocketRead", initLen, *dataLength, (UInt8 *)data);
427
428	#if SSL_OT_DOT || (SSL_OT_DEBUG && !SSL_OT_IO_TRACE)
429	if((rtn == 0) && (*dataLength == 0)) {
430		/* keep UI alive */
431		outputDot();
432	}
433	#endif
434	#if SSL_DISPL_WOULD_BLOCK
435	if(rtn == errSSLWouldBlock) {
436		printf("."); fflush(stdout);
437	}
438	#endif
439	return rtn;
440}
441
442int oneAtATime = 0;
443
444OSStatus SocketWrite(
445	SSLConnectionRef 	connection,
446	const void	 		*data,
447	size_t 				*dataLength)	/* IN/OUT */
448{
449	size_t		bytesSent = 0;
450	int			sock = (int)((long)connection);
451	int 		length;
452	size_t		dataLen = *dataLength;
453	const UInt8 *dataPtr = (UInt8 *)data;
454	OSStatus	ortn;
455
456	if(oneAtATime && (*dataLength > 1)) {
457		size_t i;
458		size_t outLen;
459		size_t thisMove;
460
461		outLen = 0;
462		for(i=0; i<dataLen; i++) {
463			thisMove = 1;
464			ortn = SocketWrite(connection, dataPtr, &thisMove);
465			outLen += thisMove;
466			dataPtr++;
467			if(ortn) {
468				return ortn;
469			}
470		}
471		return errSecSuccess;
472	}
473	*dataLength = 0;
474
475    do {
476        length = write(sock,
477				(char*)dataPtr + bytesSent,
478				dataLen - bytesSent);
479    } while ((length > 0) &&
480			 ( (bytesSent += length) < dataLen) );
481
482	if(length <= 0) {
483		int theErr = errno;
484		switch(theErr) {
485			case EAGAIN:
486				ortn = errSSLWouldBlock; break;
487			case EPIPE:
488				ortn = errSSLClosedAbort; break;
489			default:
490				dprintf(("SocketWrite: write(%u) error %d\n",
491					  (unsigned)(dataLen - bytesSent), theErr));
492				ortn = errSecIO;
493				break;
494		}
495	}
496	else {
497		ortn = errSecSuccess;
498	}
499	tprintf("SocketWrite", dataLen, bytesSent, dataPtr);
500	*dataLength = bytesSent;
501	return ortn;
502}
503