1//
//  ssl-48-split.c
3//  libsecurity_ssl
4//
5//
6
7
8#include <stdbool.h>
9#include <pthread.h>
10#include <fcntl.h>
11#include <sys/mman.h>
12#include <unistd.h>
13
14#include <CoreFoundation/CoreFoundation.h>
15
16#include <AssertMacros.h>
17#include <Security/SecureTransportPriv.h> /* SSLSetOption */
18#include <Security/SecureTransport.h>
19#include <Security/SecPolicy.h>
20#include <Security/SecTrust.h>
21#include <Security/SecIdentity.h>
22#include <Security/SecIdentityPriv.h>
23#include <Security/SecCertificatePriv.h>
24#include <Security/SecKeyPriv.h>
25#include <Security/SecItem.h>
26#include <Security/SecRandom.h>
27
28#include <string.h>
29#include <sys/types.h>
30#include <sys/socket.h>
31#include <errno.h>
32#include <stdlib.h>
33#include <mach/mach_time.h>
34
35#if TARGET_OS_IPHONE
36#include <Security/SecRSAKey.h>
37#endif
38
39#include "ssl_regressions.h"
40#include "ssl-utils.h"
41
42#include <tls_stream_parser.h>
43
44
45typedef struct {
46    SSLContextRef st;
47    bool is_server;
48    int comm;
49    CFArrayRef certs;
50    int write_counter;
51    tls_stream_parser_t parser;
52    size_t write_size;
53} ssl_test_handle;
54
55
56#pragma mark -
57#pragma mark SecureTransport support
58
/*
 * Debug tracing of raw socket writes. Build with -DSSL_48_HEXDUMP=1 to
 * enable it; by default hexdump() expands to nothing.
 */
#ifndef SSL_48_HEXDUMP
#define SSL_48_HEXDUMP 0
#endif

#if SSL_48_HEXDUMP
/*
 * Print a header line naming the operation `s`, then `len` bytes from
 * `bytes` as hex, 16 bytes per line.
 */
static void hexdump(const char *s, const uint8_t *bytes, size_t len) {
    printf("socket %s(%p, %zu)\n", s, (const void *)bytes, len);
    for (size_t ix = 0; ix < len; ++ix) {
        /* Break lines between groups of 16, not before the first byte. */
        if (ix != 0 && ix % 16 == 0)
            printf("\n");
        printf("%02X ", bytes[ix]);
    }
    printf("\n");
}
#else
#define hexdump(string, bytes, len)
#endif
73
74static OSStatus SocketWrite(SSLConnectionRef h, const void *data, size_t *length)
75{
76    ssl_test_handle *handle =(ssl_test_handle *)h;
77    int conn = handle->comm;
78	size_t len = *length;
79	uint8_t *ptr = (uint8_t *)data;
80
81    if(handle->is_server) {
82        //printf("SocketWrite: server write len=%zd\n", len);
83
84        tls_buffer buffer;
85        buffer.data = ptr;
86        buffer.length = len;
87        tls_stream_parser_parse(handle->parser, buffer);
88    }
89
90    do {
91        ssize_t ret;
92        do {
93            hexdump("write", ptr, len);
94            ret = write((int)conn, ptr, len);
95        } while ((ret < 0) && (errno == EAGAIN || errno == EINTR));
96        if (ret > 0) {
97            len -= ret;
98            ptr += ret;
99        }
100        else
101            return -36;
102    } while (len > 0);
103
104    *length = *length - len;
105    return errSecSuccess;
106}
107
108static OSStatus SocketRead(SSLConnectionRef h, void *data, size_t *length)
109{
110    const ssl_test_handle *handle=h;
111    int conn = handle->comm;
112	size_t len = *length;
113	uint8_t *ptr = (uint8_t *)data;
114
115
116    do {
117        ssize_t ret;
118        do {
119            ret = read((int)conn, ptr, len);
120        } while ((ret < 0) && (errno == EAGAIN || errno == EINTR));
121        if (ret > 0) {
122            len -= ret;
123            ptr += ret;
124        }
125        else
126            return -36;
127    } while (len > 0);
128
129    if(len!=0)
130        printf("Something went wrong here... len=%d\n", (int)len);
131
132    *length = *length - len;
133    return errSecSuccess;
134}
135
136static int process(tls_stream_parser_ctx_t ctx, tls_buffer record)
137{
138    ssl_test_handle *handle = (ssl_test_handle *)ctx;
139
140    // printf("processing record len=%zd, type=%d\n", record.length, record.data[0]);
141    if(record.data[0]==tls_record_type_AppData) {
142        handle->write_counter++;
143        // printf("record count = %d\n", handle->write_counter);
144    }
145
146    return 0;
147}
148
149
150static void *securetransport_ssl_thread(void *arg)
151{
152    OSStatus ortn;
153    ssl_test_handle * ssl = (ssl_test_handle *)arg;
154    SSLContextRef ctx = ssl->st;
155    bool got_server_auth = false;
156
157    //uint64_t start = mach_absolute_time();
158    do {
159        ortn = SSLHandshake(ctx);
160
161        if (ortn == errSSLServerAuthCompleted)
162        {
163            require_string(!got_server_auth, out, "second server auth");
164            got_server_auth = true;
165        }
166    } while (ortn == errSSLWouldBlock
167             || ortn == errSSLServerAuthCompleted);
168
169    require_noerr_action_quiet(ortn, out,
170                               fprintf(stderr, "Fell out of SSLHandshake with error: %d\n", (int)ortn));
171
172    unsigned char ibuf[90000], obuf[45000];
173
174    if (ssl->is_server) {
175        size_t len;
176        SecRandomCopyBytes(kSecRandomDefault, ssl->write_size, obuf);
177        require_noerr(ortn = SSLWrite(ctx, obuf, ssl->write_size, &len), out);
178        require_action(len == ssl->write_size, out, ortn = -1);
179        require_noerr(ortn = SSLWrite(ctx, obuf, ssl->write_size, &len), out);
180        require_action(len == ssl->write_size, out, ortn = -1);
181    } else {
182        size_t len = ssl->write_size*2;
183        size_t olen;
184        unsigned char *p = ibuf;
185        while (len) {
186            require_noerr(ortn = SSLRead(ctx, p, len, &olen), out);
187            len -= olen;
188            p += olen;
189        }
190    }
191
192out:
193    SSLClose(ctx);
194    CFRelease(ctx);
195    close(ssl->comm);
196    pthread_exit((void *)(intptr_t)ortn);
197    return NULL;
198}
199
200static void
201ssl_test_handle_destroy(ssl_test_handle *handle)
202{
203    if(handle) {
204        if(handle->parser) tls_stream_parser_destroy(handle->parser);
205        free(handle);
206    }
207}
208
209static ssl_test_handle *
210ssl_test_handle_create(bool server, int comm, CFArrayRef certs)
211{
212    ssl_test_handle *handle = calloc(1, sizeof(ssl_test_handle));
213    SSLContextRef ctx = SSLCreateContext(kCFAllocatorDefault, server?kSSLServerSide:kSSLClientSide, kSSLStreamType);
214
215    require(handle, out);
216    require(ctx, out);
217
218    require_noerr(SSLSetIOFuncs(ctx,
219                                (SSLReadFunc)SocketRead, (SSLWriteFunc)SocketWrite), out);
220    require_noerr(SSLSetConnection(ctx, (SSLConnectionRef)handle), out);
221
222    if (server)
223        require_noerr(SSLSetCertificate(ctx, certs), out);
224
225    require_noerr(SSLSetSessionOption(ctx,
226                                      kSSLSessionOptionBreakOnServerAuth, true), out);
227
228    /* Tell SecureTransport to not check certs itself: it will break out of the
229     handshake to let us take care of it instead. */
230    require_noerr(SSLSetEnableCertVerify(ctx, false), out);
231
232    handle->is_server = server;
233    handle->comm = comm;
234    handle->certs = certs;
235    handle->st = ctx;
236    handle->write_counter = 0;
237    handle->parser = tls_stream_parser_create(handle, process);
238
239    return handle;
240
241out:
242    if (handle) free(handle);
243    if (ctx) CFRelease(ctx);
244    return NULL;
245}
246
247static SSLCipherSuite ciphers[] = {
248    TLS_RSA_WITH_AES_128_CBC_SHA,
249    //FIXME: re-enable this test when its fixed.
250    //TLS_RSA_WITH_RC4_128_SHA,
251};
252static int nciphers = sizeof(ciphers)/sizeof(ciphers[0]);
253
254static SSLProtocolVersion versions[] = {
255    kSSLProtocol3,
256    kTLSProtocol1,
257    kTLSProtocol11,
258    kTLSProtocol12,
259};
260static int nversions = sizeof(versions)/sizeof(versions[0]);
261
/*
 * { write size, expected count when nosplit, expected count when split }
 *
 * Column 0 is the byte count of each of the server's two SSLWrite calls.
 * Columns 1 and 2 are the number of application-data records process()
 * should count without and with the one-byte record split. The sizes
 * cluster around 16384 and 32768, presumably the record-size boundaries.
 */
static const size_t wsizes[][3] = {
    {       1,  2,  2 },
    {       2,  2,  3 },
    {       3,  2,  3 },
    {       4,  2,  3 },
    {   16384,  2,  3 },
    {   16385,  4,  4 },
    {   16386,  4,  6 },
    {   16387,  4,  7 },
    {   16388,  4,  7 },
    {   32768,  4,  7 },
    {   32769,  6,  7 },
    {   32770,  6,  8 },
    {   32771,  6, 10 },
    {   32772,  6, 11 },
    {   32773,  6, 11 },
};
static const int nwsizes = sizeof(wsizes)/sizeof(wsizes[0]);
281
282static void
283tests(void)
284{
285    pthread_t client_thread, server_thread;
286    CFArrayRef server_certs = server_chain();
287    ok(server_certs, "got server certs");
288
289    int i,j,k,s;
290
291    for(i=0; i<nciphers; i++)
292    for(j=0; j<nversions; j++)
293    for(k=0; k<nwsizes; k++)
294    for(s=0; s<3; s++)
295    {
296        int sp[2];
297        if (socketpair(AF_UNIX, SOCK_STREAM, 0, sp)) exit(errno);
298        fcntl(sp[0], F_SETNOSIGPIPE, 1);
299        fcntl(sp[1], F_SETNOSIGPIPE, 1);
300
301        ssl_test_handle *server, *client;
302
303        server = ssl_test_handle_create(true /*server*/, sp[0], server_certs);
304        client = ssl_test_handle_create(false/*client*/, sp[1], NULL);
305
306        server->write_size = wsizes[k][0];
307        client->write_size = wsizes[k][0];
308
309        require(client, out);
310        require(server, out);
311
312        require_noerr(SSLSetProtocolVersionMax(client->st, versions[j]), out);
313        require_noerr(SSLSetEnabledCiphers(client->st, &ciphers[i], 1), out);
314        if(s) {
315            // s=0: default (should be enabled)
316            // s=1: explicit enable
317            // s=2: expliciti disable
318            require_noerr(SSLSetSessionOption(server->st, kSSLSessionOptionSendOneByteRecord, (s==1)?true:false), out);
319        }
320        // printf("**** Test Case: i=%d, j=%d, k=%d (%zd), s=%d ****\n", i, j, k, wsizes[k][0], s);
321
322        pthread_create(&client_thread, NULL, securetransport_ssl_thread, client);
323        pthread_create(&server_thread, NULL, securetransport_ssl_thread, server);
324
325        int server_err, client_err;
326        pthread_join(client_thread, (void*)&client_err);
327        pthread_join(server_thread, (void*)&server_err);
328
329        ok(!server_err, "Server error = %d", server_err);
330        ok(!client_err, "Client error = %d", client_err);
331
332        /* one byte split is expected only for AES when using TLS 1.0 or lower, and when not disabled */
333        bool expected_split = (i==0) && (s!=2) && (versions[j]<=kTLSProtocol1);
334        int expected_count = (int)(expected_split ? wsizes[k][2]: wsizes[k][1]);
335
336        is(server->write_counter, expected_count, "wrong number of data records");
337
338        // fprintf(stderr, "Server write counter = %d, expected %d\n", server->write_counter, expected_count);
339
340out:
341        ssl_test_handle_destroy(client);
342        ssl_test_handle_destroy(server);
343
344    }
345    CFReleaseNull(server_certs);
346}
347
348int ssl_48_split(int argc, char *const *argv)
349{
350
351    plan_tests(1 + nciphers*nversions*nwsizes*3 * 3);
352
353
354    tests();
355
356    return 0;
357}
358