//
//  ssl-47-falsestart.c
//  regressions
//
//  Created by Fabrice Gautier on 6/7/11.
//  Copyright 2011 Apple, Inc. All rights reserved.
//
8
#include <stdbool.h>
#include <stdint.h>
#include <pthread.h>
#include <fcntl.h>
#include <sys/mman.h>
#include <unistd.h>
#include <sys/types.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <netdb.h>
#include <arpa/inet.h>
#include <CoreFoundation/CoreFoundation.h>

#include <AssertMacros.h>
#include <Security/SecureTransportPriv.h> /* SSLSetOption */
#include <Security/SecureTransport.h>

#include <string.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <errno.h>
#include <stdlib.h>
#include <mach/mach_time.h>


#include "ssl_regressions.h"
34
35
36typedef struct {
37    uint32_t session_id;
38    bool is_session_resume;
39    SSLContextRef st;
40    bool is_server;
41    bool client_side_auth;
42    bool dh_anonymous;
43    int comm;
44    CFArrayRef certs;
45} ssl_test_handle;
46
47
48
/*
 * Debug helper: dump the bytes handed to the socket write callback,
 * 16 bytes per row.
 *
 * Compiled out unless SSL_FALSESTART_HEXDUMP is defined. The disabled
 * variant expands to nothing, so its arguments are never evaluated.
 */
#ifdef SSL_FALSESTART_HEXDUMP
static void hexdump(const uint8_t *bytes, size_t len) {
    /* %zu for size_t and a void* for %p: anything else is UB in printf. */
    printf("socket write(%p, %zu)\n", (const void *)bytes, len);
    for (size_t ix = 0; ix < len; ++ix) {
        /* Break rows before every 16th byte, but not before the first. */
        if (ix != 0 && (ix % 16) == 0)
            printf("\n");
        printf("%02X ", bytes[ix]);
    }
    printf("\n");
}
#else
#define hexdump(bytes, len)
#endif
63
64static int SocketConnect(const char *hostName, int port)
65{
66    struct sockaddr_in  addr;
67    struct in_addr      host;
68	int					sock;
69    int                 err;
70    struct hostent      *ent;
71
72    if (hostName[0] >= '0' && hostName[0] <= '9')
73    {
74        host.s_addr = inet_addr(hostName);
75    }
76    else {
77		unsigned dex;
78#define GETHOST_RETRIES 5
79		/* seeing a lot of soft failures here that I really don't want to track down */
80		for(dex=0; dex<GETHOST_RETRIES; dex++) {
81			if(dex != 0) {
82				printf("\n...retrying gethostbyname(%s)", hostName);
83			}
84			ent = gethostbyname(hostName);
85			if(ent != NULL) {
86				break;
87			}
88		}
89        if(ent == NULL) {
90			printf("\n***gethostbyname(%s) returned: %s\n", hostName, hstrerror(h_errno));
91            return -1;
92        }
93        memcpy(&host, ent->h_addr, sizeof(struct in_addr));
94    }
95
96
97    sock = socket(AF_INET, SOCK_STREAM, 0);
98    addr.sin_addr = host;
99    addr.sin_port = htons((u_short)port);
100
101    addr.sin_family = AF_INET;
102    err = connect(sock, (struct sockaddr *) &addr, sizeof(struct sockaddr_in));
103
104    if(err!=0)
105    {
106        perror("connect failed");
107        return err;
108    }
109
110    /* make non blocking */
111    fcntl(sock, F_SETFL, O_NONBLOCK);
112
113
114    return sock;
115}
116
117
118static OSStatus SocketWrite(SSLConnectionRef conn, const void *data, size_t *length)
119{
120	size_t len = *length;
121	uint8_t *ptr = (uint8_t *)data;
122
123    do {
124        ssize_t ret;
125        do {
126            hexdump(ptr, len);
127            ret = write((int)conn, ptr, len);
128            if (ret < 0)
129                perror("send");
130        } while ((ret < 0) && (errno == EAGAIN || errno == EINTR));
131        if (ret > 0) {
132            len -= ret;
133            ptr += ret;
134        }
135        else
136            return -36;
137    } while (len > 0);
138
139    *length = *length - len;
140    return errSecSuccess;
141}
142
143static
144OSStatus SocketRead(
145                    SSLConnectionRef 	connection,
146                    void 				*data,
147                    size_t 				*dataLength)
148{
149    int fd = (int)connection;
150    ssize_t len;
151
152    len = read(fd, data, *dataLength);
153
154    if(len<0) {
155        int theErr = errno;
156        switch(theErr) {
157            case EAGAIN:
158                //printf("SocketRead: EAGAIN\n");
159                *dataLength=0;
160                /* nonblocking, no data */
161                return errSSLWouldBlock;
162            default:
163                perror("SocketRead");
164                return -36;
165        }
166    }
167
168    if(len<(ssize_t)*dataLength) {
169        *dataLength=len;
170        return errSSLWouldBlock;
171    }
172
173    return errSecSuccess;
174}
175
176static SSLContextRef make_ssl_ref(int sock, SSLProtocol maxprot, Boolean false_start)
177{
178    SSLContextRef ctx = NULL;
179
180    require_noerr(SSLNewContext(false, &ctx), out);
181    require_noerr(SSLSetIOFuncs(ctx,
182                                (SSLReadFunc)SocketRead, (SSLWriteFunc)SocketWrite), out);
183    require_noerr(SSLSetConnection(ctx, (SSLConnectionRef)(intptr_t)sock), out);
184
185    require_noerr(SSLSetSessionOption(ctx,
186                                      kSSLSessionOptionBreakOnServerAuth, true), out);
187
188    require_noerr(SSLSetSessionOption(ctx,
189                                      kSSLSessionOptionFalseStart, false_start), out);
190
191    require_noerr(SSLSetProtocolVersionMax(ctx, maxprot), out);
192
193    return ctx;
194out:
195    if (ctx)
196        SSLDisposeContext(ctx);
197    return NULL;
198}
199
/*
 * HTTP request sent once the connection is up, and the buffer that receives
 * the server's reply. Both are private to this file. Without `static`, these
 * generic names risk duplicate-symbol clashes with other regression files
 * linked into the same binary.
 */
static const char request[]="GET / HTTP/1.1\n\n";
static char reply[2048];
202
203static OSStatus securetransport(ssl_test_handle * ssl)
204{
205    OSStatus ortn;
206    SSLContextRef ctx = ssl->st;
207    SecTrustRef trust = NULL;
208    bool got_server_auth = false, got_client_cert_req = false;
209
210    ortn = SSLHandshake(ctx);
211    //fprintf(stderr, "Fell out of SSLHandshake with error: %ld\n", (long)ortn);
212
213    size_t sent, received;
214    const char *r=request;
215    size_t l=sizeof(request);
216
217    do {
218
219        ortn = SSLWrite(ctx, r, l, &sent);
220
221        if(ortn == errSSLWouldBlock) {
222                r+=sent;
223                l-=sent;
224        }
225
226        if (ortn == errSSLServerAuthCompleted)
227        {
228            require_string(!got_server_auth, out, "second server auth");
229            require_string(!got_client_cert_req, out, "got client cert req before server auth");
230            got_server_auth = true;
231            require_string(!trust, out, "Got errSSLServerAuthCompleted twice?");
232            /* verify peer cert chain */
233            require_noerr(SSLCopyPeerTrust(ctx, &trust), out);
234            SecTrustResultType trust_result = 0;
235            /* this won't verify without setting up a trusted anchor */
236            require_noerr(SecTrustEvaluate(trust, &trust_result), out);
237        }
238
239    } while(ortn == errSSLWouldBlock || ortn == errSSLServerAuthCompleted);
240
241    //fprintf(stderr, "\nHTTP Request Sent\n");
242
243    require_noerr_action_quiet(ortn, out, printf("SSLWrite failed with err %ld\n", (long)ortn));
244
245    require_string(got_server_auth, out, "never got server auth");
246
247    do {
248        ortn = SSLRead(ctx, reply, sizeof(reply)-1, &received);
249        //fprintf(stderr, "r"); usleep(1000);
250    } while(ortn == errSSLWouldBlock);
251
252    //fprintf(stderr, "\n");
253
254    require_noerr_action_quiet(ortn, out, printf("SSLRead failed with err %ld\n", (long)ortn));
255
256    reply[received]=0;
257
258    //fprintf(stderr, "HTTP reply:\n");
259    //fprintf(stderr, "%s\n",reply);
260
261out:
262    SSLClose(ctx);
263    SSLDisposeContext(ctx);
264    if (trust) CFRelease(trust);
265
266    return ortn;
267}
268
269
270
271static ssl_test_handle *
272ssl_test_handle_create(int comm, SSLProtocol maxprot, Boolean false_start)
273{
274    ssl_test_handle *handle = calloc(1, sizeof(ssl_test_handle));
275    if (handle) {
276        handle->comm = comm;
277        handle->st = make_ssl_ref(comm, maxprot, false_start);
278    }
279    return handle;
280}
281
282static
283struct s_server {
284    char *host;
285    int port;
286    SSLProtocol maxprot;
287} servers[] = {
288    /* Good tls 1.2 servers */
289    {"encrypted.google.com", 443, kTLSProtocol12 },
290    {"www.amazon.com",443, kTLSProtocol12 },
291    {"www.mikestoolbox.org",443, kTLSProtocol12 },
292};
293
294#define NSERVERS (int)(sizeof(servers)/sizeof(servers[0]))
295#define NLOOPS 1
296
297static void
298tests(void)
299{
300    int p;
301    int fs;
302
303    for(p=0; p<NSERVERS;p++) {
304    for(int loops=0; loops<NLOOPS; loops++) {
305    for(fs=0;fs<2; fs++) {
306
307        ssl_test_handle *client;
308
309        int s;
310        OSStatus r;
311
312        s=SocketConnect(servers[p].host, servers[p].port);
313        if(s<0) {
314            fail("connect failed with err=%d - %s:%d (try %d)", s, servers[p].host, servers[p].port, loops);
315            break;
316        }
317
318        client = ssl_test_handle_create(s, servers[p].maxprot, fs);
319
320        r=securetransport(client);
321        ok(!r, "handshake failed with err=%ld - %s:%d (try %d), false start=%d", (long)r, servers[p].host, servers[p].port, loops, fs);
322
323        close(s);
324    } } }
325}
326
327int ssl_47_falsestart(int argc, char *const *argv)
328{
329        plan_tests(NSERVERS*NLOOPS*2);
330
331        tests();
332
333        return 0;
334}
335