1//
2//  ssl-51-state.c
3//  libsecurity_ssl
4//
5
6#include <stdbool.h>
7#include <pthread.h>
8#include <fcntl.h>
9#include <sys/mman.h>
10#include <unistd.h>
11
12#include <CoreFoundation/CoreFoundation.h>
13
14#include <AssertMacros.h>
15#include <Security/SecureTransportPriv.h> /* SSLSetOption */
16#include <Security/SecureTransport.h>
17#include <Security/SecPolicy.h>
18#include <Security/SecTrust.h>
19#include <Security/SecIdentity.h>
20#include <Security/SecIdentityPriv.h>
21#include <Security/SecCertificatePriv.h>
22#include <Security/SecKeyPriv.h>
23#include <Security/SecItem.h>
24#include <Security/SecRandom.h>
25
26#include <string.h>
27#include <sys/types.h>
28#include <sys/socket.h>
29#include <errno.h>
30#include <stdlib.h>
31#include <mach/mach_time.h>
32
33#if TARGET_OS_IPHONE
34#include <Security/SecRSAKey.h>
35#endif
36
37#include "ssl_regressions.h"
38#include "ssl-utils.h"
39
40#include <tls_stream_parser.h>
41#include <tls_handshake.h>
42#include <tls_record.h>
43
44#include <sys/queue.h>
45
46
/* Debug tracing is compiled out: the macro discards its arguments without evaluating them. */
#define test_printf(...)
48
49/* extern struct ccrng_state *ccDRBGGetRngState(); */
50#include <CommonCrypto/CommonRandomSPI.h>
51#define CCRNGSTATE ccDRBGGetRngState()
52
/*
 * One encrypted record written by the coreTLS side (tls_handshake_write_callback)
 * and waiting in ssl_test_handle.rec_queue until SecureTransport reads it
 * through SocketRead.
 */
struct RecQueueItem {
    STAILQ_ENTRY(RecQueueItem) next; /* link to next queued entry or NULL */
    tls_buffer                 record; /* encrypted bytes; allocated with tls_buffer_alloc(), released with tls_buffer_free() */
    size_t                     offset; /* number of bytes of `record` already returned by SocketRead */
};
58
/*
 * Per-connection test state: a SecureTransport context on one side and a
 * coreTLS peer (stream parser + record layer + handshake) on the other,
 * connected in memory through SocketWrite/SocketRead.
 */
typedef struct {
    SSLContextRef st;               /* SecureTransport context under test */
    tls_stream_parser_t parser;     /* receives every byte `st` writes (see SocketWrite) */
    tls_record_t record;            /* coreTLS record layer: decrypts incoming and encrypts outgoing records */
    tls_handshake_t hdsk;           /* coreTLS handshake state; created as the server side */
    STAILQ_HEAD(, RecQueueItem) rec_queue; // records written by the coreTLS server, drained by SocketRead
    int ready_count;                /* number of write-ready notifications seen by tls_handshake_ready_callback */
} ssl_test_handle;
67
68
69static
70int tls_buffer_alloc(tls_buffer *buf, size_t length)
71{
72    buf->data = malloc(length);
73    if(!buf->data) return -ENOMEM;
74    buf->length = length;
75    return 0;
76}
77
78static
79int tls_buffer_free(tls_buffer *buf)
80{
81    free(buf->data);
82    buf->data = NULL;
83    buf->length = 0;
84    return 0;
85}
86
87#pragma mark -
88#pragma mark SecureTransport support
89
#if 0
/* Debug helper: print a header line, then `len` bytes of `bytes` in hex, 16 per line. */
static void hexdump(const char *s, const uint8_t *bytes, size_t len) {
    size_t ix;
    printf("socket %s(%p, %zu)\n", s, (const void *)bytes, len);
    for (ix = 0; ix < len; ++ix) {
        /* Break lines every 16 bytes; the header already ended the first line. */
        if (ix != 0 && !(ix % 16))
            printf("\n");
        printf("%02X ", bytes[ix]);
    }
    printf("\n");
}
#else
/* Hex dumping is disabled: hexdump() expands to nothing and its arguments are not evaluated. */
#define hexdump(string, bytes, len)
#endif
104
105static OSStatus SocketWrite(SSLConnectionRef h, const void *data, size_t *length)
106{
107    ssl_test_handle *handle =(ssl_test_handle *)h;
108
109	size_t len = *length;
110	uint8_t *ptr = (uint8_t *)data;
111
112    tls_buffer buffer;
113    buffer.data = ptr;
114    buffer.length = len;
115    return tls_stream_parser_parse(handle->parser, buffer);
116}
117
118static OSStatus SocketRead(SSLConnectionRef h, void *data, size_t *length)
119{
120    ssl_test_handle *handle =(ssl_test_handle *)h;
121
122    test_printf("%s: %p requesting len=%zd\n", __FUNCTION__, h, *length);
123
124    struct RecQueueItem *item = STAILQ_FIRST(&handle->rec_queue);
125
126    if(item==NULL) {
127        test_printf("%s: %p no data available\n", __FUNCTION__, h);
128        return errSSLWouldBlock;
129    }
130
131    size_t avail = item->record.length - item->offset;
132
133    test_printf("%s: %p %zd bytes available in %p\n", __FUNCTION__, h, avail, item);
134
135    if(avail > *length) {
136        memcpy(data, item->record.data+item->offset, *length);
137        item->offset += *length;
138    } else {
139        memcpy(data, item->record.data+item->offset, avail);
140        *length = avail;
141        STAILQ_REMOVE_HEAD(&handle->rec_queue, next);
142        tls_buffer_free(&item->record);
143        free(item);
144    }
145
146    test_printf("%s: %p %zd bytes read\n", __FUNCTION__, h, *length);
147
148
149    return 0;
150}
151
152static int process(tls_stream_parser_ctx_t ctx, tls_buffer record)
153{
154    ssl_test_handle *h = (ssl_test_handle *)ctx;
155    tls_buffer decrypted;
156    uint8_t ct;
157    int err;
158
159    test_printf("%s: %p processing %zd bytes\n", __FUNCTION__, ctx, record.length);
160
161
162    decrypted.length = tls_record_decrypted_size(h->record, record.length);
163    decrypted.data = malloc(decrypted.length);
164
165    require_action(decrypted.data, errOut, err=-ENOMEM);
166    require_noerr((err=tls_record_decrypt(h->record, record, &decrypted, &ct)), errOut);
167
168    test_printf("%s: %p decrypted %zd bytes, ct=%d\n", __FUNCTION__, ctx, decrypted.length, ct);
169
170    err=tls_handshake_process(h->hdsk, decrypted, ct);
171
172    test_printf("%s: %p processed, err=%d\n", __FUNCTION__, ctx, err);
173
174errOut:
175    return err;
176}
177
178static int
179tls_handshake_write_callback(tls_handshake_ctx_t ctx, const tls_buffer data, uint8_t content_type)
180{
181    int err = 0;
182    ssl_test_handle *handle = (ssl_test_handle *)ctx;
183
184    test_printf("%s: %p writing data ct=%d, len=%zd\n", __FUNCTION__, ctx, content_type, data.length);
185
186    struct RecQueueItem *item = malloc(sizeof(struct RecQueueItem));
187    require_action(item, errOut, err=-ENOMEM);
188
189    err=tls_buffer_alloc(&item->record, tls_record_encrypted_size(handle->record, content_type, data.length));
190    require_noerr(err, errOut);
191
192    err=tls_record_encrypt(handle->record, data, content_type, &item->record);
193    require_noerr(err, errOut);
194
195    item->offset = 0;
196
197    test_printf("%s: %p queing %zd encrypted bytes, item=%p\n", __FUNCTION__, ctx, item->record.length, item);
198
199    STAILQ_INSERT_TAIL(&handle->rec_queue, item, next);
200
201    return 0;
202
203errOut:
204    if(item) {
205        tls_buffer_free(&item->record);
206        free(item);
207    }
208    return err;
209}
210
211
212static int
213tls_handshake_message_callback(tls_handshake_ctx_t ctx, tls_handshake_message_t event)
214{
215    ssl_test_handle __unused *handle = (ssl_test_handle *)ctx;
216
217    test_printf("%s: %p event = %d\n", __FUNCTION__, handle, event);
218
219    int err = 0;
220
221    return err;
222}
223
224
225
226static uint8_t appdata[] = "appdata";
227
228tls_buffer appdata_buffer = {
229    .data = appdata,
230    .length = sizeof(appdata),
231};
232
233
234static void
235tls_handshake_ready_callback(tls_handshake_ctx_t ctx, bool write, bool ready)
236{
237    ssl_test_handle *handle = (ssl_test_handle *)ctx;
238
239    test_printf("%s: %p %s ready=%d\n", __FUNCTION__, handle, write?"write":"read", ready);
240
241    if(ready) {
242        if(write) {
243            if(handle->ready_count == 0) {
244                tls_handshake_request_renegotiation(handle->hdsk);
245            } else {
246                tls_handshake_write_callback(ctx, appdata_buffer, tls_record_type_AppData);
247            }
248            handle->ready_count++;;
249        }
250    }
251}
252
253static int
254tls_handshake_set_retransmit_timer_callback(tls_handshake_ctx_t ctx, int attempt)
255{
256    ssl_test_handle __unused *handle = (ssl_test_handle *)ctx;
257
258    test_printf("%s: %p attempt = %d\n", __FUNCTION__, handle, attempt);
259
260    return -1;
261}
262
263static
264int mySSLRecordInitPendingCiphersFunc(tls_handshake_ctx_t ref,
265                                      uint16_t            selectedCipher,
266                                      bool                server,
267                                      tls_buffer           key)
268{
269    ssl_test_handle *handle = (ssl_test_handle *)ref;
270
271    test_printf("%s: %p, cipher=%04x, server=%d\n", __FUNCTION__, ref, selectedCipher, server);
272    return tls_record_init_pending_ciphers(handle->record, selectedCipher, server, key);
273}
274
275static
276int mySSLRecordAdvanceWriteCipherFunc(tls_handshake_ctx_t ref)
277{
278    ssl_test_handle *handle = (ssl_test_handle *)ref;
279    test_printf("%s: %p\n", __FUNCTION__, ref);
280    return tls_record_advance_write_cipher(handle->record);
281}
282
283static
284int mySSLRecordRollbackWriteCipherFunc(tls_handshake_ctx_t ref)
285{
286    ssl_test_handle *handle = (ssl_test_handle *)ref;
287    test_printf("%s: %p\n", __FUNCTION__, ref);
288    return tls_record_rollback_write_cipher(handle->record);
289}
290
291static
292int mySSLRecordAdvanceReadCipherFunc(tls_handshake_ctx_t ref)
293{
294    ssl_test_handle *handle = (ssl_test_handle *)ref;
295    test_printf("%s: %p\n", __FUNCTION__, ref);
296    return tls_record_advance_read_cipher(handle->record);
297}
298
299static
300int mySSLRecordSetProtocolVersionFunc(tls_handshake_ctx_t ref,
301                                      tls_protocol_version  protocolVersion)
302{
303    ssl_test_handle *handle = (ssl_test_handle *)ref;
304    test_printf("%s: %p, version=%04x\n", __FUNCTION__, ref, protocolVersion);
305    return tls_record_set_protocol_version(handle->record, protocolVersion);
306}
307
308
309static int
310tls_handshake_save_session_data_callback(tls_handshake_ctx_t ctx, tls_buffer sessionKey, tls_buffer sessionData)
311{
312    ssl_test_handle __unused *handle = (ssl_test_handle *)ctx;
313
314    test_printf("%s: %p\n", __FUNCTION__, handle);
315
316    return -1;
317}
318
319static int
320tls_handshake_load_session_data_callback(tls_handshake_ctx_t ctx, tls_buffer sessionKey, tls_buffer *sessionData)
321{
322    ssl_test_handle __unused *handle = (ssl_test_handle *)ctx;
323
324    test_printf("%s: %p\n", __FUNCTION__, handle);
325
326    return -1;
327}
328
329static int
330tls_handshake_delete_session_data_callback(tls_handshake_ctx_t ctx, tls_buffer sessionKey)
331{
332    ssl_test_handle __unused *handle = (ssl_test_handle *)ctx;
333
334    test_printf("%s: %p\n", __FUNCTION__, handle);
335
336    return -1;
337}
338
339static int
340tls_handshake_delete_all_sessions_callback(tls_handshake_ctx_t ctx)
341{
342    ssl_test_handle __unused *handle = (ssl_test_handle *)ctx;
343
344    test_printf("%s: %p\n", __FUNCTION__, handle);
345
346    return -1;
347}
348
349/* TLS callbacks */
350tls_handshake_callbacks_t tls_handshake_callbacks = {
351    .write = tls_handshake_write_callback,
352    .message = tls_handshake_message_callback,
353    .ready = tls_handshake_ready_callback,
354    .set_retransmit_timer = tls_handshake_set_retransmit_timer_callback,
355    .init_pending_cipher = mySSLRecordInitPendingCiphersFunc,
356    .advance_write_cipher = mySSLRecordAdvanceWriteCipherFunc,
357    .rollback_write_cipher = mySSLRecordRollbackWriteCipherFunc,
358    .advance_read_cipher = mySSLRecordAdvanceReadCipherFunc,
359    .set_protocol_version = mySSLRecordSetProtocolVersionFunc,
360    .load_session_data = tls_handshake_load_session_data_callback,
361    .save_session_data = tls_handshake_save_session_data_callback,
362    .delete_session_data = tls_handshake_delete_session_data_callback,
363    .delete_all_sessions = tls_handshake_delete_all_sessions_callback,
364};
365
366
367static void
368ssl_test_handle_destroy(ssl_test_handle *handle)
369{
370    if(handle) {
371        if(handle->parser) tls_stream_parser_destroy(handle->parser);
372        if(handle->record) tls_record_destroy(handle->record);
373        if(handle->hdsk) tls_handshake_destroy(handle->hdsk);
374        if(handle->st) CFRelease(handle->st);
375        free(handle);
376    }
377}
378
379static uint16_t ciphers[] = {
380    TLS_PSK_WITH_AES_128_CBC_SHA,
381};
382static int nciphers = sizeof(ciphers)/sizeof(ciphers[0]);
383
384static SSLCipherSuite ciphersuites[] = {
385    TLS_PSK_WITH_AES_128_CBC_SHA,
386};
387static int nciphersuites = sizeof(ciphersuites)/sizeof(ciphersuites[0]);
388
389
390
391static uint8_t shared_secret[] = "secret";
392
393tls_buffer psk_secret = {
394    .data = shared_secret,
395    .length = sizeof(shared_secret),
396};
397
398static ssl_test_handle *
399ssl_test_handle_create(bool server)
400{
401    ssl_test_handle *handle = calloc(1, sizeof(ssl_test_handle));
402    SSLContextRef ctx = SSLCreateContext(kCFAllocatorDefault, server?kSSLServerSide:kSSLClientSide, kSSLStreamType);
403
404    require(handle, out);
405    require(ctx, out);
406
407    require_noerr(SSLSetIOFuncs(ctx, (SSLReadFunc)SocketRead, (SSLWriteFunc)SocketWrite), out);
408    require_noerr(SSLSetConnection(ctx, (SSLConnectionRef)handle), out);
409    require_noerr(SSLSetSessionOption(ctx, kSSLSessionOptionBreakOnServerAuth, true), out);
410    require_noerr(SSLSetEnabledCiphers(ctx, ciphersuites, nciphersuites), out);
411    require_noerr(SSLSetPSKSharedSecret(ctx, shared_secret, sizeof(shared_secret)), out);
412
413    handle->st = ctx;
414    handle->parser = tls_stream_parser_create(handle, process);
415    handle->record = tls_record_create(false, CCRNGSTATE);
416    handle->hdsk = tls_handshake_create(false, true); // server.
417
418    require_noerr(tls_handshake_set_ciphersuites(handle->hdsk, ciphers, nciphers), out);
419    require_noerr(tls_handshake_set_callbacks(handle->hdsk, &tls_handshake_callbacks, handle), out);
420    require_noerr(tls_handshake_set_psk_secret(handle->hdsk, &psk_secret), out);
421    require_noerr(tls_handshake_set_renegotiation(handle->hdsk, true), out);
422
423    // Initialize the record queue
424    STAILQ_INIT(&handle->rec_queue);
425
426    return handle;
427
428out:
429    if (handle) free(handle);
430    if (ctx) CFRelease(ctx);
431    return NULL;
432}
433
434static void
435tests(void)
436{
437    OSStatus ortn;
438
439    ssl_test_handle *client;
440    SSLSessionState state;
441
442    client = ssl_test_handle_create(false);
443
444    require_action(client, out, ortn = -1);
445
446    ortn = SSLGetSessionState(client->st, &state);
447    require_noerr(ortn, out);
448    is(state, kSSLIdle, "State should be Idle");
449
450    do {
451        ortn = SSLHandshake(client->st);
452
453        require_noerr(ortn = SSLGetSessionState(client->st, &state), out);
454        test_printf("SSLHandshake returned err=%d\n", (int)ortn);
455
456        if (ortn == errSSLPeerAuthCompleted || ortn == errSSLWouldBlock)
457        {
458            require_action(state==kSSLHandshake, out, ortn = -1);
459        }
460
461    } while(ortn==errSSLWouldBlock ||
462            ortn==errSSLPeerAuthCompleted);
463
464
465    is(ortn, 0, "Unexpected SSLHandshake exit code");
466    is(state, kSSLConnected, "State should be Connected");
467
468    uint8_t buffer[128];
469    size_t available = 0;
470
471    test_printf("Initial handshake done\n");
472
473    do {
474        ortn = SSLRead(client->st, buffer, sizeof(buffer), &available);
475        require_noerr(ortn = SSLGetSessionState(client->st, &state), out);
476
477        test_printf("SSLRead returned err=%d, avail=%zd\n", (int)ortn, available);
478        if (ortn == errSSLPeerAuthCompleted)
479        {
480            require_action(state==kSSLHandshake, out, ortn = -1);
481        }
482
483    } while(available==0);
484
485    is(ortn, 0, "Unexpected SSLRead exit code");
486    is(state, kSSLConnected, "State should be Connected");
487
488
489out:
490    is(ortn, 0, "Final result is non zero");
491    ssl_test_handle_destroy(client);
492
493}
494
/*
 * Regression entry point for ssl-51-state: plans the six checks made by
 * tests() and runs them. The command-line arguments are not used.
 * Always returns 0.
 */
int ssl_51_state(int argc, char *const *argv)
{
    (void)argc;
    (void)argv;

    plan_tests(6);
    tests();

    return 0;
}
504