1/*
2 * Copyright (c) 2011,2013-2014 Apple Inc. All Rights Reserved.
3 *
4 * @APPLE_LICENSE_HEADER_START@
5 *
6 * This file contains Original Code and/or Modifications of Original Code
7 * as defined in and that are subject to the Apple Public Source License
8 * Version 2.0 (the 'License'). You may not use this file except in
9 * compliance with the License. Please obtain a copy of the License at
10 * http://www.opensource.apple.com/apsl/ and read it before using this
11 * file.
12 *
13 * The Original Code and all software distributed under the License are
14 * distributed on an 'AS IS' basis, WITHOUT WARRANTY OF ANY KIND, EITHER
15 * EXPRESS OR IMPLIED, AND APPLE HEREBY DISCLAIMS ALL SUCH WARRANTIES,
16 * INCLUDING WITHOUT LIMITATION, ANY WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE, QUIET ENJOYMENT OR NON-INFRINGEMENT.
18 * Please see the License for the specific language governing rights and
19 * limitations under the License.
20 *
21 * @APPLE_LICENSE_HEADER_END@
22 */
23
24
25#include <stdbool.h>
26#include <pthread.h>
27#include <fcntl.h>
28#include <sys/mman.h>
29#include <unistd.h>
30#include <sys/types.h>
31#include <netinet/in.h>
32#include <sys/socket.h>
33#include <netdb.h>
34#include <arpa/inet.h>
35#include <CoreFoundation/CoreFoundation.h>
36
37#include <AssertMacros.h>
38#include <Security/SecureTransportPriv.h> /* SSLSetOption */
39#include <Security/SecureTransport.h>
40
41#include <string.h>
42#include <sys/types.h>
43#include <sys/socket.h>
44#include <errno.h>
45#include <stdlib.h>
46#include <mach/mach_time.h>
47
48
49#include "ssl_regressions.h"
50
51
52typedef struct {
53    uint32_t session_id;
54    bool is_session_resume;
55    SSLContextRef st;
56    bool is_server;
57    bool client_side_auth;
58    bool dh_anonymous;
59    int comm;
60    CFArrayRef certs;
61} ssl_test_handle;
62
63
64
/*
 * Debug helper: print a buffer handed to SocketWrite() as hex, 16 bytes per
 * line. It is compiled out by default; build with -DSSL_TEST_HEXDUMP to
 * enable it. When disabled, hexdump() expands to a no-op statement, so call
 * sites stay valid either way.
 */
#ifdef SSL_TEST_HEXDUMP
static void hexdump(const uint8_t *bytes, size_t len) {
    size_t ix;
    /* %p requires a void pointer and %zu matches size_t. */
    printf("socket write(%p, %zu)\n", (const void *)bytes, len);
    for (ix = 0; ix < len; ++ix) {
        if (!(ix % 16))
            printf("\n");
        printf("%02X ", bytes[ix]);
    }
    printf("\n");
}
#else
#define hexdump(bytes, len) ((void)0)
#endif
79
/*
 * Open a TCP connection to hostName:port and switch the socket to
 * non-blocking mode.
 *
 * If hostName starts with a digit, it is parsed as a dotted-quad IPv4
 * literal with inet_addr(). Otherwise it is resolved with gethostbyname(),
 * retried up to GETHOST_RETRIES times. Only AF_INET is used.
 *
 * Returns the connected socket descriptor, or -1 on failure. No descriptor
 * is left open on failure.
 */
static int SocketConnect(const char *hostName, int port)
{
    struct sockaddr_in  addr;
    struct in_addr      host;
    int                 sock;
    int                 flags;

    if (hostName[0] >= '0' && hostName[0] <= '9')
    {
        host.s_addr = inet_addr(hostName);
        if (host.s_addr == INADDR_NONE) {
            printf("\n***invalid IPv4 address: %s\n", hostName);
            return -1;
        }
    }
    else {
        struct hostent *ent = NULL;
        unsigned dex;
#define GETHOST_RETRIES 5
        /* seeing a lot of soft failures here that I really don't want to track down */
        for(dex=0; dex<GETHOST_RETRIES; dex++) {
            if(dex != 0) {
                printf("\n...retrying gethostbyname(%s)", hostName);
            }
            ent = gethostbyname(hostName);
            if(ent != NULL) {
                break;
            }
        }
        if(ent == NULL) {
            printf("\n***gethostbyname(%s) returned: %s\n", hostName, hstrerror(h_errno));
            return -1;
        }
        /* h_addr_list[0] is the first address (what the h_addr macro names). */
        memcpy(&host, ent->h_addr_list[0], sizeof(struct in_addr));
    }

    sock = socket(AF_INET, SOCK_STREAM, 0);
    if (sock < 0) {
        perror("socket failed");
        return -1;
    }

    /* Zero the whole address so padding fields are not stack garbage. */
    memset(&addr, 0, sizeof(addr));
    addr.sin_family = AF_INET;
    addr.sin_port = htons((in_port_t)port);
    addr.sin_addr = host;

    if (connect(sock, (struct sockaddr *) &addr, sizeof(addr)) != 0)
    {
        perror("connect failed");
        close(sock);    /* don't leak the descriptor on failure */
        return -1;
    }

    /* Make non-blocking while preserving any existing file status flags.
     * This is best effort, as before: on failure, warn and keep the socket. */
    flags = fcntl(sock, F_GETFL, 0);
    if (flags < 0 || fcntl(sock, F_SETFL, flags | O_NONBLOCK) < 0) {
        perror("fcntl(O_NONBLOCK) failed");
    }

    return sock;
}
132
133
134static OSStatus SocketWrite(SSLConnectionRef conn, const void *data, size_t *length)
135{
136	size_t len = *length;
137	uint8_t *ptr = (uint8_t *)data;
138
139    do {
140        ssize_t ret;
141        do {
142            hexdump(ptr, len);
143            ret = write((int)conn, ptr, len);
144            if (ret < 0)
145                perror("send");
146        } while ((ret < 0) && (errno == EAGAIN || errno == EINTR));
147        if (ret > 0) {
148            len -= ret;
149            ptr += ret;
150        }
151        else
152            return -36;
153    } while (len > 0);
154
155    *length = *length - len;
156    return errSecSuccess;
157}
158
159static
160OSStatus SocketRead(
161                    SSLConnectionRef 	connection,
162                    void 				*data,
163                    size_t 				*dataLength)
164{
165    int fd = (int)connection;
166    ssize_t len;
167
168    len = read(fd, data, *dataLength);
169
170    if(len<0) {
171        int theErr = errno;
172        switch(theErr) {
173            case EAGAIN:
174                //printf("SocketRead: EAGAIN\n");
175                *dataLength=0;
176                /* nonblocking, no data */
177                return errSSLWouldBlock;
178            default:
179                perror("SocketRead");
180                return -36;
181        }
182    }
183
184    if(len<(ssize_t)*dataLength) {
185        *dataLength=len;
186        return errSSLWouldBlock;
187    }
188
189    return errSecSuccess;
190}
191
192static SSLContextRef make_ssl_ref(int sock, SSLProtocol maxprot, Boolean false_start)
193{
194    SSLContextRef ctx = NULL;
195
196    require_noerr(SSLNewContext(false, &ctx), out);
197    require_noerr(SSLSetIOFuncs(ctx,
198                                (SSLReadFunc)SocketRead, (SSLWriteFunc)SocketWrite), out);
199    require_noerr(SSLSetConnection(ctx, (SSLConnectionRef)(intptr_t)sock), out);
200
201    require_noerr(SSLSetSessionOption(ctx,
202                                      kSSLSessionOptionBreakOnServerAuth, true), out);
203
204    require_noerr(SSLSetSessionOption(ctx,
205                                      kSSLSessionOptionFalseStart, false_start), out);
206
207    require_noerr(SSLSetProtocolVersionMax(ctx, maxprot), out);
208
209    return ctx;
210out:
211    if (ctx)
212        SSLDisposeContext(ctx);
213    return NULL;
214}
215
/* Request text written by securetransport(). sizeof(request) also counts
 * the terminating NUL. Both buffers are private to this test file. */
static const char request[]="GET / HTTP/1.1\n\n";
/* Receives the NUL-terminated start of the server's reply. */
static char reply[2048];
218
219static OSStatus securetransport(ssl_test_handle * ssl)
220{
221    OSStatus ortn;
222    SSLContextRef ctx = ssl->st;
223    SecTrustRef trust = NULL;
224    bool got_server_auth = false, got_client_cert_req = false;
225
226    ortn = SSLHandshake(ctx);
227    //fprintf(stderr, "Fell out of SSLHandshake with error: %ld\n", (long)ortn);
228
229    size_t sent, received;
230    const char *r=request;
231    size_t l=sizeof(request);
232
233    do {
234
235        ortn = SSLWrite(ctx, r, l, &sent);
236
237        if(ortn == errSSLWouldBlock) {
238                r+=sent;
239                l-=sent;
240        }
241
242        if (ortn == errSSLServerAuthCompleted)
243        {
244            require_string(!got_server_auth, out, "second server auth");
245            require_string(!got_client_cert_req, out, "got client cert req before server auth");
246            got_server_auth = true;
247            require_string(!trust, out, "Got errSSLServerAuthCompleted twice?");
248            /* verify peer cert chain */
249            require_noerr(SSLCopyPeerTrust(ctx, &trust), out);
250            SecTrustResultType trust_result = 0;
251            /* this won't verify without setting up a trusted anchor */
252            require_noerr(SecTrustEvaluate(trust, &trust_result), out);
253        }
254    } while(ortn == errSSLWouldBlock || ortn == errSSLServerAuthCompleted);
255
256    //fprintf(stderr, "\nHTTP Request Sent\n");
257
258    require_noerr_action_quiet(ortn, out, printf("SSLWrite failed with err %ld\n", (long)ortn));
259
260    require_string(got_server_auth, out, "never got server auth");
261
262    do {
263        ortn = SSLRead(ctx, reply, sizeof(reply)-1, &received);
264        //fprintf(stderr, "r"); usleep(1000);
265    } while(ortn == errSSLWouldBlock);
266
267    //fprintf(stderr, "\n");
268
269    require_noerr_action_quiet(ortn, out, printf("SSLRead failed with err %ld\n", (long)ortn));
270
271    reply[received]=0;
272
273    //fprintf(stderr, "HTTP reply:\n");
274    //fprintf(stderr, "%s\n",reply);
275
276out:
277    SSLClose(ctx);
278    SSLDisposeContext(ctx);
279    if (trust) CFRelease(trust);
280
281    return ortn;
282}
283
284static ssl_test_handle *
285ssl_test_handle_create(int comm, SSLProtocol maxprot, Boolean false_start)
286{
287    ssl_test_handle *handle = calloc(1, sizeof(ssl_test_handle));
288    if (handle) {
289        handle->comm = comm;
290        handle->st = make_ssl_ref(comm, maxprot, false_start);
291    }
292    return handle;
293}
294
295static
296struct s_server {
297    char *host;
298    int port;
299    SSLProtocol maxprot;
300} servers[] = {
301    /* Good tls 1.2 servers */
302    {"encrypted.google.com", 443, kTLSProtocol12 },
303    {"www.amazon.com",443, kTLSProtocol12 },
304    //{"www.mikestoolbox.org",443, kTLSProtocol12 },
305};
306
307#define NSERVERS (int)(sizeof(servers)/sizeof(servers[0]))
308#define NLOOPS 1
309
310static void
311tests(void)
312{
313    int p;
314    int fs;
315
316    for(p=0; p<NSERVERS;p++) {
317    for(int loops=0; loops<NLOOPS; loops++) {
318    for(fs=0;fs<2; fs++) {
319
320        ssl_test_handle *client;
321
322        int s;
323        OSStatus r;
324
325        s=SocketConnect(servers[p].host, servers[p].port);
326        if(s<0) {
327            fail("connect failed with err=%d - %s:%d (try %d)", s, servers[p].host, servers[p].port, loops);
328            break;
329        }
330
331        client = ssl_test_handle_create(s, servers[p].maxprot, fs);
332
333        r=securetransport(client);
334        ok(!r, "handshake failed with err=%ld - %s:%d (try %d), false start=%d", (long)r, servers[p].host, servers[p].port, loops, fs);
335
336        close(s);
337        free(client);
338    } } }
339}
340
341int ssl_47_falsestart(int argc, char *const *argv)
342{
343        plan_tests(NSERVERS*NLOOPS*2);
344
345        tests();
346
347        return 0;
348}
349