1/*
2 * Copyright (c) 2011-2014 Apple Inc. All Rights Reserved.
3 *
4 * @APPLE_LICENSE_HEADER_START@
5 *
6 * This file contains Original Code and/or Modifications of Original Code
7 * as defined in and that are subject to the Apple Public Source License
8 * Version 2.0 (the 'License'). You may not use this file except in
9 * compliance with the License. Please obtain a copy of the License at
10 * http://www.opensource.apple.com/apsl/ and read it before using this
11 * file.
12 *
13 * The Original Code and all software distributed under the License are
14 * distributed on an 'AS IS' basis, WITHOUT WARRANTY OF ANY KIND, EITHER
15 * EXPRESS OR IMPLIED, AND APPLE HEREBY DISCLAIMS ALL SUCH WARRANTIES,
16 * INCLUDING WITHOUT LIMITATION, ANY WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE, QUIET ENJOYMENT OR NON-INFRINGEMENT.
18 * Please see the License for the specific language governing rights and
19 * limitations under the License.
20 *
21 * @APPLE_LICENSE_HEADER_END@
22 */
23
24
25#include <Security/Security.h>
26#include <Security/SecBase.h>
27
28#include "../sslViewer/sslAppUtils.h"
29
30#include <stdlib.h>
31#include <sys/types.h>
32#include <sys/socket.h>
33#include <netinet/in.h>
34#include <arpa/inet.h>
35#include <stdio.h>
36#include <errno.h>
37#include <unistd.h> /* close() */
38#include <string.h> /* memset() */
39#include <fcntl.h>
40#include <time.h>
41
42#ifdef NO_SERVER
43#include <securityd/spi.h>
44#endif
45
46#include "ssl-utils.h"
47
/* Address of the DTLS echo server (loopback by default; the commented-out
 * line is an alternate address kept for reference). */
#define SERVER "127.0.0.1"
//#define SERVER "17.201.58.114"

enum {
    PORT   = 23232, /* UDP port the echo server listens on */
    BUFLEN = 128,   /* size of the message / echo buffer used in main() */
    COUNT  = 10     /* number of messages sent to the server */
};
53
54#if 0
/*
 * Print `len` bytes of `data` as a hex dump: 16 bytes per row, each row
 * prefixed with its 4-digit hex offset, followed by a final newline.
 */
static void dumppacket(const unsigned char *data, unsigned long len)
{
    for (unsigned long off = 0; off < len; ++off) {
        unsigned long col = off % 16;   /* position within the current row */

        if (col == 0)
            printf("%04lx :", off);
        printf(" %02x", data[off]);
        if (col == 15)
            printf("\n");
    }
    printf("\n");
}
66#endif
67
/*
 * Receive-side buffering for SocketRead: one whole datagram is read into
 * readBuffer, then handed to SecureTransport in pieces. readOff is the index
 * of the next unread byte and readLeft the number of bytes still pending;
 * when readLeft reaches 0 the next datagram is read.
 *
 * 2K should be enough for everybody (datagrams larger than MTU are truncated
 * by read()).
 */
#define MTU 2048
static unsigned char readBuffer[MTU];
static size_t        readOff = 0;   /* size_t to match readLeft and *dataLength */
static size_t        readLeft = 0;
73
74static
75OSStatus SocketRead(
76                    SSLConnectionRef 	connection,
77                    void 				*data,
78                    size_t 				*dataLength)
79{
80    int fd = (int)connection;
81    ssize_t len;
82    uint8_t *d=readBuffer;
83
84    if(readLeft==0)
85    {
86        len = read(fd, readBuffer, MTU);
87
88        if(len>0) {
89            readOff=0;
90            readLeft=(size_t) len;
91            printf("SocketRead: %ld bytes... epoch: %02x seq=%02x%02x\n",
92                   len, d[4], d[9], d[10]);
93
94        } else {
95            int theErr = errno;
96            switch(theErr) {
97                case EAGAIN:
98                    //printf("SocketRead: EAGAIN\n");
99                    *dataLength=0;
100                    /* nonblocking, no data */
101                    return errSSLWouldBlock;
102                default:
103                    perror("SocketRead");
104                    return errSecIO;
105            }
106        }
107    }
108
109    if(readLeft<*dataLength) {
110        *dataLength=readLeft;
111    }
112
113    memcpy(data, readBuffer+readOff, *dataLength);
114    readLeft-=*dataLength;
115    readOff+=*dataLength;
116
117    return errSecSuccess;
118
119}
120
121static
122OSStatus SocketWrite(
123                     SSLConnectionRef   connection,
124                     const void         *data,
125                     size_t 			*dataLength)	/* IN/OUT */
126{
127    int fd = (int)connection;
128    ssize_t len;
129    OSStatus err = errSecSuccess;
130    const uint8_t *d=data;
131
132#if 0
133    if((rand()&3)==1) {
134
135        /* drop 1/8th packets */
136        printf("SocketWrite: Drop %ld bytes... epoch: %02x seq=%02x%02x\n",
137               *dataLength, d[4], d[9], d[10]);
138        return errSecSuccess;
139
140    }
141#endif
142
143    len = send(fd, data, *dataLength, 0);
144
145    if(len>0) {
146        *dataLength=(size_t)len;
147        printf("SocketWrite: Sent %ld bytes... epoch: %02x seq=%02x%02x\n",
148               len, d[4], d[9], d[10]);
149        return err;
150    }
151
152    int theErr = errno;
153    switch(theErr) {
154        case EAGAIN:
155            /* nonblocking, no data */
156            err = errSSLWouldBlock;
157            break;
158        default:
159            perror("SocketWrite");
160            err = errSecIO;
161            break;
162    }
163
164    return err;
165
166}
167
168
169int main(int argc, char **argv)
170{
171    int fd;
172    struct sockaddr_in sa;
173
174    if ((fd=socket(AF_INET, SOCK_DGRAM, 0))==-1) {
175        perror("socket");
176        exit(-1);
177    }
178
179#ifdef NO_SERVER
180# if DEBUG
181    securityd_init();
182# endif
183#endif
184
185    memset((char *) &sa, 0, sizeof(sa));
186    sa.sin_family = AF_INET;
187    sa.sin_port = htons(PORT);
188    if (inet_aton(SERVER, &sa.sin_addr)==0) {
189        fprintf(stderr, "inet_aton() failed\n");
190        exit(1);
191    }
192
193    time_t seed=time(NULL);
194//    time_t seed=1298952499;
195    srand((unsigned)seed);
196    printf("Random drop initialized with seed = %lu\n", seed);
197
198    if(connect(fd, (struct sockaddr *)&sa, sizeof(sa))==-1)
199    {
200        perror("connect");
201        return errno;
202    }
203
204    /* Change to non blocking io */
205    fcntl(fd, F_SETFL, O_NONBLOCK);
206
207    SSLConnectionRef c=(SSLConnectionRef)(intptr_t)fd;
208
209
210    OSStatus            ortn;
211    SSLContextRef       ctx = NULL;
212
213    SSLClientCertificateState certState;
214    SSLCipherSuite negCipher;
215    SSLProtocol negVersion;
216
217	/*
218	 * Set up a SecureTransport session.
219	 */
220	ortn = SSLNewDatagramContext(false, &ctx);
221	if(ortn) {
222		printSslErrStr("SSLNewDatagramContext", ortn);
223		return ortn;
224	}
225	ortn = SSLSetIOFuncs(ctx, SocketRead, SocketWrite);
226	if(ortn) {
227		printSslErrStr("SSLSetIOFuncs", ortn);
228		return ortn;
229	}
230
231    ortn = SSLSetConnection(ctx, c);
232	if(ortn) {
233		printSslErrStr("SSLSetConnection", ortn);
234		return ortn;
235	}
236
237    ortn = SSLSetMaxDatagramRecordSize(ctx, 400);
238    if(ortn) {
239		printSslErrStr("SSLSetMaxDatagramRecordSize", ortn);
240        return ortn;
241	}
242
243    /* Lets not verify the cert, which is a random test cert */
244    ortn = SSLSetEnableCertVerify(ctx, false);
245    if(ortn) {
246        printSslErrStr("SSLSetEnableCertVerify", ortn);
247        return ortn;
248    }
249
250    ortn = SSLSetCertificate(ctx, server_chain());
251    if(ortn) {
252        printSslErrStr("SSLSetCertificate", ortn);
253        return ortn;
254    }
255
256    do {
257		ortn = SSLHandshake(ctx);
258	    if(ortn == errSSLWouldBlock) {
259		/* keep UI responsive */
260		sslOutputDot();
261	    }
262    } while (ortn == errSSLWouldBlock);
263
264
265    SSLGetClientCertificateState(ctx, &certState);
266	SSLGetNegotiatedCipher(ctx, &negCipher);
267	SSLGetNegotiatedProtocolVersion(ctx, &negVersion);
268
269    int count;
270    size_t len, readLen, writeLen;
271    char buffer[BUFLEN];
272
273    count = 0;
274    while(count<COUNT) {
275        int timeout = 10000;
276
277        snprintf(buffer, BUFLEN, "Message %d", count);
278        len = strlen(buffer);
279
280        ortn=SSLWrite(ctx, buffer, len, &writeLen);
281        if(ortn) {
282            printSslErrStr("SSLWrite", ortn);
283            break;
284        }
285        printf("Wrote %lu bytes\n", writeLen);
286
287        count++;
288
289        do {
290            ortn=SSLRead(ctx, buffer, BUFLEN, &readLen);
291        } while((ortn==errSSLWouldBlock) && (timeout--));
292        if(ortn==errSSLWouldBlock) {
293            printf("Echo timeout...\n");
294            continue;
295        }
296        if(ortn) {
297                printSslErrStr("SSLRead", ortn);
298                break;
299        }
300        buffer[readLen]=0;
301        printf("Received %lu bytes: %s\n", readLen, buffer);
302
303     }
304
305    SSLClose(ctx);
306
307    SSLDisposeContext(ctx);
308
309    return ortn;
310}
311