1//
2//  tlssocket.c
3//  tlsnke
4//
5//  Created by Fabrice Gautier on 1/6/12.
6//  Copyright (c) 2012 Apple, Inc. All rights reserved.
7//
8
#include <Security/SecureTransportPriv.h>
#include <string.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>

#include <stdint.h>
#include <stdlib.h>
#include <stdio.h>
#include <assert.h>

#include <net/kext_net.h>

#include "tlssocket.h"
#include "tlsnke.h"

#include <AssertMacros.h>
#include <errno.h>
25
26/* TLSSocket functions */
27
28static
29int TLSSocket_Read(SSLRecordContextRef ref,
30                        SSLRecord *rec)
31{
32    int socket = (int)ref;
33    int rc;
34    ssize_t sz;
35    struct sockaddr_in client_addr;
36    int avail;
37    socklen_t avail_size;
38    struct cmsghdr *cmsg;
39    tls_record_hdr_t hdr;
40    struct msghdr msg;
41    struct iovec iov;
42    int cbuf_len=CMSG_SPACE(sizeof(*hdr))+1024;
43    uint8_t cbuf[cbuf_len];
44
45
46    //    printf("%s: Waiting for some data...\n", __FUNCTION__);
47    /* PEEK only... */
48    char b;
49    rc = (int)recv(socket, &b, 1, MSG_PEEK);
50
51    if(rc==-1)
52    {
53        if(errno==EAGAIN)
54            return errSSLRecordWouldBlock;
55        else {
56            perror("recv");
57            return errno;
58        }
59    }
60
61    /* get the next packet size */
62    avail_size = sizeof(avail);
63    rc = getsockopt(socket, SOL_SOCKET, SO_NREAD, &avail, &avail_size);
64
65    check_noerr(rc);
66    check(avail_size==sizeof(avail));
67
68    if(rc || (avail_size !=sizeof(avail)))
69        return errSSLRecordInternal;
70
71    //    printf("%s: Available = %d\n", __FUNCTION__, avail);
72
73    if(avail==0)
74        return errSSLRecordWouldBlock;
75
76
77    /* Allocate a buffer */
78    rec->contents.data = malloc(avail);
79    rec->contents.length = avail;
80
81    /* read the message */
82    iov.iov_base = rec->contents.data;
83    iov.iov_len = rec->contents.length;
84    msg.msg_name = &client_addr;
85    msg.msg_namelen = sizeof(client_addr);
86    msg.msg_iov = &iov;
87    msg.msg_iovlen = 1;
88    msg.msg_control = cbuf;
89    msg.msg_controllen = cbuf_len;
90
91    sz = recvmsg(socket, &msg, 0);
92    check(sz==avail);
93
94    //    printf("%s: received = %ld, ctrl: l=%d f=%x\n", __FUNCTION__, sz, msg.msg_controllen, msg.msg_flags);
95    rec->contents.length = sz;
96
97    cmsg = CMSG_FIRSTHDR(&msg);
98    check(cmsg);
99    if(!cmsg)
100        return 0;
101
102    check(cmsg->cmsg_type == SCM_TLS_HEADER);
103    check(cmsg->cmsg_level == SOL_SOCKET);
104    check(cmsg->cmsg_len == CMSG_LEN(sizeof(*hdr)));
105    hdr = (tls_record_hdr_t)CMSG_DATA(cmsg);
106    check(hdr);
107
108    /* print msg info */
109    /*
110    printf("%s: rc=%d, msg: %ld , cmsg = %d, %x, %x, hdr = %d, %x - from %s:%d\n", __FUNCTION__, rc,
111           iov.iov_len,
112           cmsg->cmsg_len, cmsg->cmsg_level, cmsg->cmsg_type,
113           hdr->content_type, hdr->protocol_version,
114           inet_ntoa(client_addr.sin_addr),ntohs(client_addr.sin_port));
115    */
116    rec->contentType = hdr->content_type;
117    rec->protocolVersion = hdr->protocol_version;
118
119    if(rec->contentType==SSL_RecordTypeChangeCipher) {
120        printf("%s: Received ChangeCipherSpec message\n", __FUNCTION__);
121    }
122    return 0;
123}
124
125static
126int TLSSocket_Free(SSLRecordContextRef ref,
127                         SSLRecord rec)
128{
129    free(rec.contents.data);
130    return 0;
131}
132
133static
134int TLSSocket_Write(SSLRecordContextRef ref,
135                          SSLRecord rec)
136{
137    int socket = (int)ref;
138    ssize_t sz;
139
140    struct msghdr msg;
141    struct iovec iov;
142    tls_record_hdr_t hdr;
143    struct cmsghdr *cmsg;
144    int cbuf_len=CMSG_SPACE(sizeof(*hdr));
145    uint8_t cbuf[cbuf_len];
146
147    if(rec.contentType==SSL_RecordTypeChangeCipher) {
148        printf("%s: Sending ChangeCipherSpec message\n", __FUNCTION__);
149    }
150    // printf("%s: fd=%d, rec.len=%ld\n", __FUNCTION__, socket, rec.contents.length);
151
152    /* write the message */
153    iov.iov_base = rec.contents.data;
154    iov.iov_len = rec.contents.length;
155    msg.msg_name = NULL;
156    msg.msg_namelen = 0;
157    msg.msg_iov = &iov;
158    msg.msg_iovlen = 1;
159    msg.msg_control = cbuf;
160    msg.msg_controllen = cbuf_len;
161
162    cmsg = CMSG_FIRSTHDR(&msg);
163    cmsg->cmsg_level = SOL_SOCKET;
164    cmsg->cmsg_type = SCM_TLS_HEADER;
165    cmsg->cmsg_len = CMSG_LEN(sizeof(*hdr));
166    hdr = (tls_record_hdr_t)CMSG_DATA(cmsg);
167    hdr->content_type = rec.contentType;
168    hdr->protocol_version = rec.protocolVersion;
169
170    /* print msg info */
171    sz = sendmsg(socket, &msg, 0);
172
173    if(sz<0)
174        perror("sendmsg");
175
176    /*
177       printf("%s: sz=%ld, msg: %ld , cmsg = %d, %d, %04x\n", __FUNCTION__, sz,
178           iov.iov_len,
179           cmsg->cmsg_len, cmsg->cmsg_level, cmsg->cmsg_type);
180    */
181
182    check(sz==rec.contents.length);
183
184    if(sz<0)
185        return (int)sz;
186    else
187        return 0;
188}
189
190
191static
192int TLSSocket_InitPendingCiphers(SSLRecordContextRef   ref,
193                                       uint16_t              selectedCipher,
194                                       bool                  server,
195                                       SSLBuffer             key)
196{
197    int socket = (int)ref;
198    int rc;
199    char *buf;
200
201    buf = malloc(key.length+3);
202    buf[0] = selectedCipher >> 8;
203    buf[1] = selectedCipher & 0xff;
204    buf[2] = server;
205    memcpy(buf+3, key.data, key.length);
206
207    printf("%s: cipher=%04x, keylen=%ld\n", __FUNCTION__, selectedCipher, key.length);
208
209    rc = setsockopt(socket, SOL_SOCKET, SO_TLS_INIT_CIPHER, buf, (socklen_t)(key.length+3));
210
211    printf("%s: rc=%d\n", __FUNCTION__, rc);
212
213    free(buf);
214
215    return rc;
216}
217
218static
219int TLSSocket_AdvanceWriteCipher(SSLRecordContextRef ref)
220{
221    int socket = (int)ref;
222    int rc;
223    rc = setsockopt(socket, SOL_SOCKET, SO_TLS_ADVANCE_WRITE_CIPHER, NULL, 0);
224
225    printf("%s: rc=%d\n", __FUNCTION__, rc);
226
227    return rc;
228}
229
230static
231int TLSSocket_RollbackWriteCipher(SSLRecordContextRef ref)
232{
233    int socket = (int)ref;
234    int rc;
235    rc = setsockopt(socket, SOL_SOCKET, SO_TLS_ROLLBACK_WRITE_CIPHER, NULL, 0);
236
237    printf("%s: rc=%d\n", __FUNCTION__, rc);
238
239    return rc;
240}
241
242static
243int TLSSocket_AdvanceReadCipher(SSLRecordContextRef    ref)
244{
245    int socket = (int)ref;
246    int rc;
247    rc = setsockopt(socket, SOL_SOCKET, SO_TLS_ADVANCE_READ_CIPHER, NULL, 0);
248
249    printf("%s: rc=%d\n", __FUNCTION__, rc);
250
251    return rc;
252}
253
254static
255int TLSSocket_SetProtocolVersion(SSLRecordContextRef    ref,
256                                 SSLProtocolVersion     protocolVersion)
257{
258    int socket = (int)ref;
259    int rc;
260    rc = setsockopt(socket, SOL_SOCKET, SO_TLS_PROTOCOL_VERSION, &protocolVersion, sizeof(protocolVersion));
261
262    printf("%s: rc=%d\n", __FUNCTION__, rc);
263
264    return rc;
265}
266
267
268static
269int TLSSocket_ServiceWriteQueue(SSLRecordContextRef    ref)
270{
271    int socket = (int)ref;
272    int rc;
273    rc = setsockopt(socket, SOL_SOCKET, SO_TLS_SERVICE_WRITE_QUEUE, NULL, 0);
274
275    return rc;
276}
277
278
279const struct SSLRecordFuncs TLSSocket_Funcs = {
280    .read                = TLSSocket_Read,
281    .write               = TLSSocket_Write,
282    .initPendingCiphers  = TLSSocket_InitPendingCiphers,
283    .advanceWriteCipher  = TLSSocket_AdvanceWriteCipher,
284    .rollbackWriteCipher = TLSSocket_RollbackWriteCipher,
285    .advanceReadCipher   = TLSSocket_AdvanceReadCipher,
286    .setProtocolVersion  = TLSSocket_SetProtocolVersion,
287    .free                = TLSSocket_Free,
288    .serviceWriteQueue   = TLSSocket_ServiceWriteQueue,
289};
290
291
292/* TLSSocket SPIs */
293
294int TLSSocket_Attach(int socket)
295{
296
297    /* Attach the TLS socket filter and return handle */
298    struct so_nke so_tlsnke;
299    int rc;
300    int handle;
301    socklen_t len;
302
303    memset(&so_tlsnke, 0, sizeof(so_tlsnke));
304    so_tlsnke.nke_handle = TLS_HANDLE_IP4;
305    rc=setsockopt(socket, SOL_SOCKET, SO_NKE, &so_tlsnke, sizeof(so_tlsnke));
306    if(rc)
307        return rc;
308
309    len = sizeof(handle);
310    rc = getsockopt(socket, SOL_SOCKET, SO_TLS_HANDLE, &handle, &len);
311    if(rc)
312        return rc;
313
314    assert(len==sizeof(handle));
315
316    return handle;
317}
318
319