1/*
2 *  dtlsEchoClient.c
3 *  Security
4 *
5 *  Created by Fabrice Gautier on 1/31/11.
6 *  Copyright 2011 Apple, Inc. All rights reserved.
7 *
8 */
9
10#include <Security/Security.h>
11#include <Security/SecBase.h>
12
13#include "../sslViewer/sslAppUtils.h"
14
#include <stdlib.h>
#include <stdint.h> /* intptr_t */
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <stdio.h>
#include <errno.h>
#include <unistd.h> /* close() */
#include <string.h> /* memset() */
#include <fcntl.h>
#include <time.h>
26
27#ifdef NO_SERVER
28#include <securityd/spi.h>
29#endif
30
31#include "ssl-utils.h"
32
/* IPv4 dotted-quad address of the DTLS echo server (parsed with inet_aton()). */
#define SERVER "127.0.0.1"
/* Alternate test server address: "17.201.58.114". */

/* Fixed parameters of the echo test. */
enum {
    PORT   = 23232, /* UDP port the echo server listens on */
    BUFLEN = 128,   /* size of the message / echo buffer in main() */
    COUNT  = 10     /* number of messages to send */
};
38
#if 0
/*
 * Debug helper: hex-dump len bytes of data to stdout, 16 bytes per row,
 * each row prefixed by its offset. Compiled out; change "#if 0" to "#if 1"
 * to use it.
 */
static void dumppacket(const unsigned char *data, unsigned long len)
{
    unsigned long i;
    for(i=0;i<len;i++)
    {
        if((i&0xf)==0) printf("%04lx :",i);
        printf(" %02x", data[i]);
        if((i&0xf)==0xf) printf("\n");
    }
    /* Terminate a partial last row only; a full row already ended with '\n'. */
    if((len&0xf)!=0) printf("\n");
}
#endif
52
53/* 2K should be enough for everybody */
54#define MTU 2048
55static unsigned char readBuffer[MTU];
56static unsigned int  readOff=0;
57static size_t        readLeft=0;
58
59static
60OSStatus SocketRead(
61                    SSLConnectionRef 	connection,
62                    void 				*data,
63                    size_t 				*dataLength)
64{
65    int fd = (int)connection;
66    ssize_t len;
67    uint8_t *d=readBuffer;
68
69    if(readLeft==0)
70    {
71        len = read(fd, readBuffer, MTU);
72
73        if(len>0) {
74            readOff=0;
75            readLeft=(size_t) len;
76            printf("SocketRead: %ld bytes... epoch: %02x seq=%02x%02x\n",
77                   len, d[4], d[9], d[10]);
78
79        } else {
80            int theErr = errno;
81            switch(theErr) {
82                case EAGAIN:
83                    //printf("SocketRead: EAGAIN\n");
84                    *dataLength=0;
85                    /* nonblocking, no data */
86                    return errSSLWouldBlock;
87                default:
88                    perror("SocketRead");
89                    return errSecIO;
90            }
91        }
92    }
93
94    if(readLeft<*dataLength) {
95        *dataLength=readLeft;
96    }
97
98    memcpy(data, readBuffer+readOff, *dataLength);
99    readLeft-=*dataLength;
100    readOff+=*dataLength;
101
102    return errSecSuccess;
103
104}
105
106static
107OSStatus SocketWrite(
108                     SSLConnectionRef   connection,
109                     const void         *data,
110                     size_t 			*dataLength)	/* IN/OUT */
111{
112    int fd = (int)connection;
113    ssize_t len;
114    OSStatus err = errSecSuccess;
115    const uint8_t *d=data;
116
117#if 0
118    if((rand()&3)==1) {
119
120        /* drop 1/8th packets */
121        printf("SocketWrite: Drop %ld bytes... epoch: %02x seq=%02x%02x\n",
122               *dataLength, d[4], d[9], d[10]);
123        return errSecSuccess;
124
125    }
126#endif
127
128    len = send(fd, data, *dataLength, 0);
129
130    if(len>0) {
131        *dataLength=(size_t)len;
132        printf("SocketWrite: Sent %ld bytes... epoch: %02x seq=%02x%02x\n",
133               len, d[4], d[9], d[10]);
134        return err;
135    }
136
137    int theErr = errno;
138    switch(theErr) {
139        case EAGAIN:
140            /* nonblocking, no data */
141            err = errSSLWouldBlock;
142            break;
143        default:
144            perror("SocketWrite");
145            err = errSecIO;
146            break;
147    }
148
149    return err;
150
151}
152
153
154int main(int argc, char **argv)
155{
156    int fd;
157    struct sockaddr_in sa;
158
159    if ((fd=socket(AF_INET, SOCK_DGRAM, 0))==-1) {
160        perror("socket");
161        exit(-1);
162    }
163
164#ifdef NO_SERVER
165# if DEBUG
166    securityd_init();
167# endif
168#endif
169
170    memset((char *) &sa, 0, sizeof(sa));
171    sa.sin_family = AF_INET;
172    sa.sin_port = htons(PORT);
173    if (inet_aton(SERVER, &sa.sin_addr)==0) {
174        fprintf(stderr, "inet_aton() failed\n");
175        exit(1);
176    }
177
178    time_t seed=time(NULL);
179//    time_t seed=1298952499;
180    srand((unsigned)seed);
181    printf("Random drop initialized with seed = %lu\n", seed);
182
183    if(connect(fd, (struct sockaddr *)&sa, sizeof(sa))==-1)
184    {
185        perror("connect");
186        return errno;
187    }
188
189    /* Change to non blocking io */
190    fcntl(fd, F_SETFL, O_NONBLOCK);
191
192    SSLConnectionRef c=(SSLConnectionRef)(intptr_t)fd;
193
194
195    OSStatus            ortn;
196    SSLContextRef       ctx = NULL;
197
198    SSLClientCertificateState certState;
199    SSLCipherSuite negCipher;
200    SSLProtocol negVersion;
201
202	/*
203	 * Set up a SecureTransport session.
204	 */
205	ortn = SSLNewDatagramContext(false, &ctx);
206	if(ortn) {
207		printSslErrStr("SSLNewDatagramContext", ortn);
208		return ortn;
209	}
210	ortn = SSLSetIOFuncs(ctx, SocketRead, SocketWrite);
211	if(ortn) {
212		printSslErrStr("SSLSetIOFuncs", ortn);
213		return ortn;
214	}
215
216    ortn = SSLSetConnection(ctx, c);
217	if(ortn) {
218		printSslErrStr("SSLSetConnection", ortn);
219		return ortn;
220	}
221
222    ortn = SSLSetMaxDatagramRecordSize(ctx, 600);
223    if(ortn) {
224		printSslErrStr("SSLSetMaxDatagramRecordSize", ortn);
225        return ortn;
226	}
227
228    /* Lets not verify the cert, which is a random test cert */
229    ortn = SSLSetEnableCertVerify(ctx, false);
230    if(ortn) {
231        printSslErrStr("SSLSetEnableCertVerify", ortn);
232        return ortn;
233    }
234
235    ortn = SSLSetCertificate(ctx, server_chain());
236    if(ortn) {
237        printSslErrStr("SSLSetCertificate", ortn);
238        return ortn;
239    }
240
241    do {
242		ortn = SSLHandshake(ctx);
243	    if(ortn == errSSLWouldBlock) {
244		/* keep UI responsive */
245		sslOutputDot();
246	    }
247    } while (ortn == errSSLWouldBlock);
248
249
250    SSLGetClientCertificateState(ctx, &certState);
251	SSLGetNegotiatedCipher(ctx, &negCipher);
252	SSLGetNegotiatedProtocolVersion(ctx, &negVersion);
253
254    int count;
255    size_t len, readLen, writeLen;
256    char buffer[BUFLEN];
257
258    count = 0;
259    while(count<COUNT) {
260        int timeout = 10000;
261
262        snprintf(buffer, BUFLEN, "Message %d", count);
263        len = strlen(buffer);
264
265        ortn=SSLWrite(ctx, buffer, len, &writeLen);
266        if(ortn) {
267            printSslErrStr("SSLWrite", ortn);
268            break;
269        }
270        printf("Wrote %lu bytes\n", writeLen);
271
272        count++;
273
274        do {
275            ortn=SSLRead(ctx, buffer, BUFLEN, &readLen);
276        } while((ortn==errSSLWouldBlock) && (timeout--));
277        if(ortn==errSSLWouldBlock) {
278            printf("Echo timeout...\n");
279            continue;
280        }
281        if(ortn) {
282                printSslErrStr("SSLRead", ortn);
283                break;
284        }
285        buffer[readLen]=0;
286        printf("Received %lu bytes: %s\n", readLen, buffer);
287
288     }
289
290    SSLClose(ctx);
291
292    SSLDisposeContext(ctx);
293
294    return ortn;
295}
296