1/*
2 * Copyright (c) 2011-2014 Apple Inc. All Rights Reserved.
3 *
4 * @APPLE_LICENSE_HEADER_START@
5 *
6 * This file contains Original Code and/or Modifications of Original Code
7 * as defined in and that are subject to the Apple Public Source License
8 * Version 2.0 (the 'License'). You may not use this file except in
9 * compliance with the License. Please obtain a copy of the License at
10 * http://www.opensource.apple.com/apsl/ and read it before using this
11 * file.
12 *
13 * The Original Code and all software distributed under the License are
14 * distributed on an 'AS IS' basis, WITHOUT WARRANTY OF ANY KIND, EITHER
15 * EXPRESS OR IMPLIED, AND APPLE HEREBY DISCLAIMS ALL SUCH WARRANTIES,
16 * INCLUDING WITHOUT LIMITATION, ANY WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE, QUIET ENJOYMENT OR NON-INFRINGEMENT.
18 * Please see the License for the specific language governing rights and
19 * limitations under the License.
20 *
21 * @APPLE_LICENSE_HEADER_END@
22 */
23
24
25#include <Security/Security.h>
26#include <Security/SecBase.h>
27
28#include "../sslViewer/sslAppUtils.h"
29
30#include <stdlib.h>
31#include <sys/types.h>
32#include <sys/socket.h>
33#include <netinet/in.h>
34#include <arpa/inet.h>
35#include <stdio.h>
36#include <errno.h>
37#include <unistd.h> /* close() */
38#include <string.h> /* memset() */
39#include <fcntl.h>
40#include <time.h>
41
42#ifdef NO_SERVER
43#include <securityd/spi.h>
44#endif
45
46#define PORT 23232
47
48#include "ssl-utils.h"
49
/*
 * Hex-dump `len` bytes starting at `data` to stdout, sixteen bytes per row.
 * Every row begins with the offset of its first byte (" %04lx :"-style
 * prefix), each byte is printed as " %02x", and one trailing newline is
 * always written after the last byte (so an empty dump prints just "\n").
 */
static void dumppacket(const unsigned char *data, unsigned long len)
{
    unsigned long off = 0;

    while (off < len) {
        unsigned long col = off & 0xf;

        if (col == 0) {
            printf("%04lx :", off);
        }
        printf(" %02x", data[off]);
        if (col == 0xf) {
            printf("\n");
        }
        off++;
    }
    printf("\n");
}
61
62
63/* 2K should be enough for everybody */
64#define MTU 2048
65static unsigned char readBuffer[MTU];
66static unsigned int  readOff=0;
67static size_t        readLeft=0;
68
69static
70OSStatus SocketRead(
71                    SSLConnectionRef 	connection,
72                    void 				*data,
73                    size_t 				*dataLength)
74{
75    int fd = (int)connection;
76    ssize_t len;
77    uint8_t *d=readBuffer;
78
79    if(readLeft==0)
80    {
81        len = read(fd, readBuffer, MTU);
82
83        if(len>0) {
84            readOff=0;
85            readLeft=(size_t) len;
86            printf("SocketRead: %ld bytes... epoch: %02x seq=%02x%02x\n",
87                   len, d[4], d[9], d[10]);
88        } else {
89            int theErr = errno;
90            switch(theErr) {
91                case EAGAIN:
92                    // printf("SocketRead: EAGAIN\n");
93                    *dataLength=0;
94                    /* nonblocking, no data */
95                    return errSSLWouldBlock;
96                default:
97                    perror("SocketRead");
98                    return errSecIO;
99            }
100        }
101    }
102
103    if(readLeft<*dataLength) {
104        *dataLength=readLeft;
105    }
106
107    memcpy(data, readBuffer+readOff, *dataLength);
108    readLeft-=*dataLength;
109    readOff+=*dataLength;
110
111
112    return errSecSuccess;
113
114}
115
116
117static
118OSStatus SocketWrite(
119                     SSLConnectionRef   connection,
120                     const void         *data,
121                     size_t 			*dataLength)	/* IN/OUT */
122{
123    int fd = (int)connection;
124    ssize_t len;
125    OSStatus err = errSecSuccess;
126    const uint8_t *d=data;
127
128#if 1
129    if((rand()&3)==1) {
130        /* drop 1/8 packets */
131        printf("SocketWrite: Drop %ld bytes... epoch: %02x seq=%02x%02x\n",
132               *dataLength, d[4], d[9], d[10]);
133        return errSecSuccess;
134    }
135#endif
136
137    len = send(fd, data, *dataLength, 0);
138
139    if(len>0) {
140        *dataLength=(size_t)len;
141
142        printf("SocketWrite: Sent %ld bytes... epoch: %02x seq=%02x%02x\n",
143               len, d[4], d[9], d[10]);
144
145        return err;
146    }
147
148    int theErr = errno;
149    switch(theErr) {
150        case EAGAIN:
151            /* nonblocking, no data */
152            err = errSSLWouldBlock;
153            break;
154        default:
155            perror("SocketWrite");
156            err = errSecIO;
157            break;
158    }
159
160    return err;
161
162}
163
164
165int main(int argc, char **argv)
166{
167    struct sockaddr_in sa; /* server address for bind */
168    struct sockaddr_in ca; /* client address for connect */
169    int fd;
170    ssize_t l;
171
172#ifdef NO_SERVER
173# if DEBUG
174    securityd_init();
175# endif
176#endif
177
178    if ((fd=socket(AF_INET, SOCK_DGRAM, 0))==-1) {
179        perror("socket");
180        return errno;
181    }
182
183    time_t seed=time(NULL);
184//    time_t seed=1298952496;
185    srand((unsigned)seed);
186    printf("Random drop initialized with seed = %lu\n", seed);
187
188    memset((char *) &sa, 0, sizeof(sa));
189    sa.sin_family = AF_INET;
190    sa.sin_port = htons(PORT);
191    sa.sin_addr.s_addr = htonl(INADDR_ANY);
192
193    if(bind (fd, (struct sockaddr *)&sa, sizeof(sa))==-1)
194    {
195        perror("bind");
196        return errno;
197    }
198
199    printf("Waiting for first packet...\n");
200    /* PEEK only... */
201    socklen_t slen=sizeof(ca);
202    char b;
203    if((l=recvfrom(fd, &b, 1, MSG_PEEK, (struct sockaddr *)&ca, &slen))==-1)
204    {
205        perror("recvfrom");
206        return errno;
207    }
208
209    printf("Received packet from %s (%ld), connecting...\n", inet_ntoa(ca.sin_addr), l);
210
211    if(connect(fd, (struct sockaddr *)&ca, sizeof(ca))==-1)
212    {
213        perror("connect");
214        return errno;
215    }
216
217    /* Change to non blocking */
218    fcntl(fd, F_SETFL, O_NONBLOCK);
219
220
221    SSLConnectionRef c=(SSLConnectionRef)(intptr_t)fd;
222
223
224    OSStatus            ortn;
225    SSLContextRef       ctx = NULL;
226
227    SSLClientCertificateState certState;
228    SSLCipherSuite negCipher;
229
230	/*
231	 * Set up a SecureTransport session.
232	 */
233	ortn = SSLNewDatagramContext(true, &ctx);
234	if(ortn) {
235		printSslErrStr("SSLNewDatagramContext", ortn);
236		return ortn;
237	}
238
239	ortn = SSLSetIOFuncs(ctx, SocketRead, SocketWrite);
240	if(ortn) {
241		printSslErrStr("SSLSetIOFuncs", ortn);
242        return ortn;
243	}
244
245    ortn = SSLSetConnection(ctx, c);
246	if(ortn) {
247		printSslErrStr("SSLSetConnection", ortn);
248		return ortn;
249	}
250
251    ortn = SSLSetDatagramHelloCookie(ctx, &ca, 32);
252    if(ortn) {
253		printSslErrStr("SSLSetDatagramHelloCookie", ortn);
254        return ortn;
255	}
256
257    ortn = SSLSetMaxDatagramRecordSize(ctx, 400);
258    if(ortn) {
259		printSslErrStr("SSLSetMaxDatagramRecordSize", ortn);
260        return ortn;
261	}
262
263    /* Lets not verify the cert, which is a random test cert */
264    ortn = SSLSetEnableCertVerify(ctx, false);
265    if(ortn) {
266        printSslErrStr("SSLSetEnableCertVerify", ortn);
267        return ortn;
268    }
269
270    ortn = SSLSetCertificate(ctx, server_chain());
271    if(ortn) {
272        printSslErrStr("SSLSetCertificate", ortn);
273        return ortn;
274    }
275
276    ortn = SSLSetClientSideAuthenticate(ctx, kAlwaysAuthenticate);
277    if(ortn) {
278        printSslErrStr("SSLSetCertificate", ortn);
279        return ortn;
280    }
281
282    printf("Server Handshake...\n");
283    do {
284		ortn = SSLHandshake(ctx);
285	    if(ortn == errSSLWouldBlock) {
286		/* keep UI responsive */
287		sslOutputDot();
288	    }
289    } while (ortn == errSSLWouldBlock);
290
291    if(ortn) {
292        printSslErrStr("SSLHandshake", ortn);
293        return ortn;
294    }
295
296    SSLGetClientCertificateState(ctx, &certState);
297	SSLGetNegotiatedCipher(ctx, &negCipher);
298
299    printf("Server Handshake done. Cipher is %s\n", sslGetCipherSuiteString(negCipher));
300
301    unsigned char buffer[MTU];
302    size_t len, readLen;
303
304    while(1) {
305        while((ortn=SSLRead(ctx, buffer, MTU, &readLen))==errSSLWouldBlock);
306        if(ortn) {
307            printSslErrStr("SSLRead", ortn);
308            break;
309        }
310        buffer[readLen]=0;
311        printf("Received %lu bytes:\n", readLen);
312        dumppacket(buffer, readLen);
313
314        ortn=SSLWrite(ctx, buffer, readLen, &len);
315        if(ortn) {
316            printSslErrStr("SSLRead", ortn);
317            break;
318        }
319        printf("Echoing %lu bytes\n", len);
320    }
321
322    SSLDisposeContext(ctx);
323
324    return ortn;
325}
326