1//
2//  ssl-49-sni.c
3//  libsecurity_ssl
4//
5//
6
7
8#include <stdbool.h>
9#include <pthread.h>
10#include <fcntl.h>
11#include <sys/mman.h>
12#include <unistd.h>
13
14#include <CoreFoundation/CoreFoundation.h>
15
16#include <AssertMacros.h>
17#include <Security/SecureTransportPriv.h> /* SSLSetOption */
18#include <Security/SecureTransport.h>
19#include <Security/SecPolicy.h>
20#include <Security/SecTrust.h>
21#include <Security/SecIdentity.h>
22#include <Security/SecIdentityPriv.h>
23#include <Security/SecCertificatePriv.h>
24#include <Security/SecKeyPriv.h>
25#include <Security/SecItem.h>
26#include <Security/SecRandom.h>
27
28#include <string.h>
29#include <sys/types.h>
30#include <sys/socket.h>
31#include <errno.h>
32#include <stdlib.h>
33#include <mach/mach_time.h>
34
35#if TARGET_OS_IPHONE
36#include <Security/SecRSAKey.h>
37#endif
38
39#include "ssl_regressions.h"
40#include "ssl-utils.h"
41
42#include <tls_stream_parser.h>
43#include <tls_handshake.h>
44#include <tls_record.h>
45
46/* extern struct ccrng_state *ccDRBGGetRngState(); */
47#include <CommonCrypto/CommonRandomSPI.h>
48#define CCRNGSTATE ccDRBGGetRngState()
49
50
51typedef struct {
52    SSLContextRef st;
53    tls_stream_parser_t parser;
54    tls_record_t record;
55    tls_handshake_t hdsk;
56} ssl_test_handle;
57
58
59#pragma mark -
60#pragma mark SecureTransport support
61
#if 0
/*
 * Debug helper (compiled out by default; flip the `#if 0` to enable).
 * Prints a header naming the operation `s`, then `len` bytes from `bytes`
 * in hex, starting a new line every 16 bytes.
 */
static void hexdump(const char *s, const uint8_t *bytes, size_t len) {
    size_t ix;
    /* %zu matches size_t and %p requires a void pointer. */
    printf("socket %s(%p, %zu)\n", s, (const void *)bytes, len);
    for (ix = 0; ix < len; ++ix) {
        if (!(ix % 16))
            printf("\n");
        printf("%02X ", bytes[ix]);
    }
    printf("\n");
}
#else
/* Disabled: expands to a no-op expression; arguments are not evaluated. */
#define hexdump(string, bytes, len) ((void)0)
#endif
76
77static OSStatus SocketWrite(SSLConnectionRef h, const void *data, size_t *length)
78{
79    ssl_test_handle *handle =(ssl_test_handle *)h;
80
81	size_t len = *length;
82	uint8_t *ptr = (uint8_t *)data;
83
84    tls_buffer buffer;
85    buffer.data = ptr;
86    buffer.length = len;
87    return tls_stream_parser_parse(handle->parser, buffer);
88}
89
90static OSStatus SocketRead(SSLConnectionRef h, void *data, size_t *length)
91{
92    return -36;
93}
94
95static int process(tls_stream_parser_ctx_t ctx, tls_buffer record)
96{
97    ssl_test_handle *h = (ssl_test_handle *)ctx;
98    tls_buffer decrypted;
99    uint8_t ct;
100    int err;
101
102    decrypted.length = tls_record_decrypted_size(h->record, record.length);
103    decrypted.data = malloc(decrypted.length);
104
105    require_action(decrypted.data, errOut, err=ENOMEM);
106    require_noerr((err=tls_record_decrypt(h->record, record, &decrypted, &ct)), errOut);
107    err=tls_handshake_process(h->hdsk, decrypted, ct);
108
109errOut:
110    return err;
111}
112
113
114static int
115tls_handshake_message_callback(tls_handshake_ctx_t ctx, tls_handshake_message_t event)
116{
117    int err = 0;
118
119    switch(event) {
120        case tls_handshake_message_client_hello:
121            err = -1234;
122            break;
123        default:
124            err = -1;
125            break;
126    }
127
128    return err;
129}
130
131static int
132tls_handshake_set_protocol_version(tls_handshake_ctx_t ctx, tls_protocol_version protocolVersion)
133{
134    return 0;
135}
136
137static int
138tls_handshake_write(tls_handshake_ctx_t ctx, const tls_buffer data, uint8_t content_type)
139{
140    return -36;
141}
142
143static int
144tls_handshake_set_retransmit_timer(tls_handshake_ctx_t ctx, int attempt)
145{
146    return -1;
147}
148
149
150static
151tls_handshake_callbacks_t tls_handshake_callbacks = {
152    .message = tls_handshake_message_callback,
153    .set_protocol_version = tls_handshake_set_protocol_version,
154    .write = tls_handshake_write,
155    .set_retransmit_timer = tls_handshake_set_retransmit_timer,
156};
157
158
159static void
160ssl_test_handle_destroy(ssl_test_handle *handle)
161{
162    if(handle) {
163        if(handle->parser) tls_stream_parser_destroy(handle->parser);
164        if(handle->record) tls_record_destroy(handle->record);
165        if(handle->hdsk) tls_handshake_destroy(handle->hdsk);
166        if(handle->st) CFRelease(handle->st);
167        free(handle);
168    }
169}
170
171static uint16_t ciphers[] = {
172    TLS_RSA_WITH_AES_128_CBC_SHA,
173    //FIXME: re-enable this test when its fixed.
174    //TLS_RSA_WITH_RC4_128_SHA,
175};
176static int nciphers = sizeof(ciphers)/sizeof(ciphers[0]);
177
178
179static ssl_test_handle *
180ssl_test_handle_create(bool server)
181{
182    ssl_test_handle *handle = calloc(1, sizeof(ssl_test_handle));
183    SSLContextRef ctx = SSLCreateContext(kCFAllocatorDefault, server?kSSLServerSide:kSSLClientSide, kSSLStreamType);
184
185    require(handle, out);
186    require(ctx, out);
187
188    require_noerr(SSLSetIOFuncs(ctx,
189                                (SSLReadFunc)SocketRead, (SSLWriteFunc)SocketWrite), out);
190    require_noerr(SSLSetConnection(ctx, (SSLConnectionRef)handle), out);
191
192    require_noerr(SSLSetSessionOption(ctx,
193                                      kSSLSessionOptionBreakOnServerAuth, true), out);
194
195    /* Tell SecureTransport to not check certs itself: it will break out of the
196     handshake to let us take care of it instead. */
197    require_noerr(SSLSetEnableCertVerify(ctx, false), out);
198
199    handle->st = ctx;
200    handle->parser = tls_stream_parser_create(handle, process);
201    handle->record = tls_record_create(false, CCRNGSTATE);
202    handle->hdsk = tls_handshake_create(false, true); // server.
203    tls_handshake_set_ciphersuites(handle->hdsk, ciphers, nciphers);
204
205    tls_handshake_set_callbacks(handle->hdsk, &tls_handshake_callbacks, handle);
206
207    return handle;
208
209out:
210    if (handle) free(handle);
211    if (ctx) CFRelease(ctx);
212    return NULL;
213}
214
215static SSLProtocolVersion versions[] = {
216    kSSLProtocol3,
217    kTLSProtocol1,
218    kTLSProtocol11,
219    kTLSProtocol12,
220};
221static int nversions = sizeof(versions)/sizeof(versions[0]);
222
/*
 * Host name given to SSLSetPeerDomainName(). It is read-only, hence const.
 * sizeof(peername) is 9 because it counts the trailing NUL. That full length
 * is what the test passes in and expects back in the SNI extension.
 */
static const char peername[] = "peername";
224
225static void
226tests(void)
227{
228    int j;
229    OSStatus ortn;
230
231    for(j=0; j<nversions; j++)
232    {
233        ssl_test_handle *client;
234        const tls_buffer *sni;
235
236        client = ssl_test_handle_create(false);
237
238        require(client, out);
239
240        require_noerr(SSLSetProtocolVersionMax(client->st, versions[j]), out);
241        require_noerr(SSLSetPeerDomainName(client->st, peername, sizeof(peername)), out);
242
243        ortn = SSLHandshake(client->st);
244
245        ok(ortn==-1234, "Unexpected Handshake exit code");
246
247        sni = tls_handshake_get_sni_hostname(client->hdsk);
248
249        if(versions[j]==kSSLProtocol3) {
250            ok(sni==NULL || sni->data==NULL,"Unexpected SNI");
251        } else {
252            ok(sni!=NULL && sni->data!=NULL &&
253               sni->length == sizeof(peername) &&
254               (memcmp(sni->data, peername, sizeof(peername))==0),
255               "SNI does not match");
256        }
257
258out:
259        ssl_test_handle_destroy(client);
260
261    }
262}
263
264int ssl_49_sni(int argc, char *const *argv)
265{
266
267    plan_tests(8);
268
269
270    tests();
271
272    return 0;
273}
274