1// 2// tlssocket.c 3// tlsnke 4// 5// Created by Fabrice Gautier on 1/6/12. 6// Copyright (c) 2012 Apple, Inc. All rights reserved. 7// 8 9#include <Security/SecureTransportPriv.h> 10#include <string.h> 11#include <netinet/in.h> 12#include <arpa/inet.h> 13 14#include <stdlib.h> 15#include <stdio.h> 16#include <assert.h> 17 18#include <net/kext_net.h> 19 20#include "tlssocket.h" 21#include "tlsnke.h" 22 23#include <AssertMacros.h> 24#include <errno.h> 25 26/* TLSSocket functions */ 27 28static 29int TLSSocket_Read(SSLRecordContextRef ref, 30 SSLRecord *rec) 31{ 32 int socket = (int)ref; 33 int rc; 34 ssize_t sz; 35 struct sockaddr_in client_addr; 36 int avail; 37 socklen_t avail_size; 38 struct cmsghdr *cmsg; 39 tls_record_hdr_t hdr; 40 struct msghdr msg; 41 struct iovec iov; 42 int cbuf_len=CMSG_SPACE(sizeof(*hdr))+1024; 43 uint8_t cbuf[cbuf_len]; 44 45 46 // printf("%s: Waiting for some data...\n", __FUNCTION__); 47 /* PEEK only... */ 48 char b; 49 rc = (int)recv(socket, &b, 1, MSG_PEEK); 50 51 if(rc==-1) 52 { 53 if(errno==EAGAIN) 54 return errSSLRecordWouldBlock; 55 else { 56 perror("recv"); 57 return errno; 58 } 59 } 60 61 /* get the next packet size */ 62 avail_size = sizeof(avail); 63 rc = getsockopt(socket, SOL_SOCKET, SO_NREAD, &avail, &avail_size); 64 65 check_noerr(rc); 66 check(avail_size==sizeof(avail)); 67 68 if(rc || (avail_size !=sizeof(avail))) 69 return errSSLRecordInternal; 70 71 // printf("%s: Available = %d\n", __FUNCTION__, avail); 72 73 if(avail==0) 74 return errSSLRecordWouldBlock; 75 76 77 /* Allocate a buffer */ 78 rec->contents.data = malloc(avail); 79 rec->contents.length = avail; 80 81 /* read the message */ 82 iov.iov_base = rec->contents.data; 83 iov.iov_len = rec->contents.length; 84 msg.msg_name = &client_addr; 85 msg.msg_namelen = sizeof(client_addr); 86 msg.msg_iov = &iov; 87 msg.msg_iovlen = 1; 88 msg.msg_control = cbuf; 89 msg.msg_controllen = cbuf_len; 90 91 sz = recvmsg(socket, &msg, 0); 92 check(sz==avail); 93 94 // printf("%s: received = %ld, ctrl: l=%d f=%x\n", __FUNCTION__, sz, msg.msg_controllen, msg.msg_flags); 95 rec->contents.length = sz; 96 97 cmsg = CMSG_FIRSTHDR(&msg); 98 check(cmsg); 99 if(!cmsg) 100 return 0; 101 102 check(cmsg->cmsg_type == SCM_TLS_HEADER); 103 check(cmsg->cmsg_level == SOL_SOCKET); 104 check(cmsg->cmsg_len == CMSG_LEN(sizeof(*hdr))); 105 hdr = (tls_record_hdr_t)CMSG_DATA(cmsg); 106 check(hdr); 107 108 /* print msg info */ 109 /* 110 printf("%s: rc=%d, msg: %ld , cmsg = %d, %x, %x, hdr = %d, %x - from %s:%d\n", __FUNCTION__, rc, 111 iov.iov_len, 112 cmsg->cmsg_len, cmsg->cmsg_level, cmsg->cmsg_type, 113 hdr->content_type, hdr->protocol_version, 114 inet_ntoa(client_addr.sin_addr),ntohs(client_addr.sin_port)); 115 */ 116 rec->contentType = hdr->content_type; 117 rec->protocolVersion = hdr->protocol_version; 118 119 if(rec->contentType==SSL_RecordTypeChangeCipher) { 120 printf("%s: Received ChangeCipherSpec message\n", __FUNCTION__); 121 } 122 return 0; 123} 124 125static 126int TLSSocket_Free(SSLRecordContextRef ref, 127 SSLRecord rec) 128{ 129 free(rec.contents.data); 130 return 0; 131} 132 133static 134int TLSSocket_Write(SSLRecordContextRef ref, 135 SSLRecord rec) 136{ 137 int socket = (int)ref; 138 ssize_t sz; 139 140 struct msghdr msg; 141 struct iovec iov; 142 tls_record_hdr_t hdr; 143 struct cmsghdr *cmsg; 144 int cbuf_len=CMSG_SPACE(sizeof(*hdr)); 145 uint8_t cbuf[cbuf_len]; 146 147 if(rec.contentType==SSL_RecordTypeChangeCipher) { 148 printf("%s: Sending ChangeCipherSpec message\n", __FUNCTION__); 149 } 150 // printf("%s: fd=%d, rec.len=%ld\n", __FUNCTION__, socket, rec.contents.length); 151 152 /* write the message */ 153 iov.iov_base = rec.contents.data; 154 iov.iov_len = rec.contents.length; 155 msg.msg_name = NULL; 156 msg.msg_namelen = 0; 157 msg.msg_iov = &iov; 158 msg.msg_iovlen = 1; 159 msg.msg_control = cbuf; 160 msg.msg_controllen = cbuf_len; 161 162 cmsg = CMSG_FIRSTHDR(&msg); 163 cmsg->cmsg_level = SOL_SOCKET; 164 cmsg->cmsg_type = SCM_TLS_HEADER; 165 cmsg->cmsg_len = CMSG_LEN(sizeof(*hdr)); 166 hdr = (tls_record_hdr_t)CMSG_DATA(cmsg); 167 hdr->content_type = rec.contentType; 168 hdr->protocol_version = rec.protocolVersion; 169 170 /* print msg info */ 171 sz = sendmsg(socket, &msg, 0); 172 173 if(sz<0) 174 perror("sendmsg"); 175 176 /* 177 printf("%s: sz=%ld, msg: %ld , cmsg = %d, %d, %04x\n", __FUNCTION__, sz, 178 iov.iov_len, 179 cmsg->cmsg_len, cmsg->cmsg_level, cmsg->cmsg_type); 180 */ 181 182 check(sz==rec.contents.length); 183 184 if(sz<0) 185 return (int)sz; 186 else 187 return 0; 188} 189 190 191static 192int TLSSocket_InitPendingCiphers(SSLRecordContextRef ref, 193 uint16_t selectedCipher, 194 bool server, 195 SSLBuffer key) 196{ 197 int socket = (int)ref; 198 int rc; 199 char *buf; 200 201 buf = malloc(key.length+3); 202 buf[0] = selectedCipher >> 8; 203 buf[1] = selectedCipher & 0xff; 204 buf[2] = server; 205 memcpy(buf+3, key.data, key.length); 206 207 printf("%s: cipher=%04x, keylen=%ld\n", __FUNCTION__, selectedCipher, key.length); 208 209 rc = setsockopt(socket, SOL_SOCKET, SO_TLS_INIT_CIPHER, buf, (socklen_t)(key.length+3)); 210 211 printf("%s: rc=%d\n", __FUNCTION__, rc); 212 213 free(buf); 214 215 return rc; 216} 217 218static 219int TLSSocket_AdvanceWriteCipher(SSLRecordContextRef ref) 220{ 221 int socket = (int)ref; 222 int rc; 223 rc = setsockopt(socket, SOL_SOCKET, SO_TLS_ADVANCE_WRITE_CIPHER, NULL, 0); 224 225 printf("%s: rc=%d\n", __FUNCTION__, rc); 226 227 return rc; 228} 229 230static 231int TLSSocket_RollbackWriteCipher(SSLRecordContextRef ref) 232{ 233 int socket = (int)ref; 234 int rc; 235 rc = setsockopt(socket, SOL_SOCKET, SO_TLS_ROLLBACK_WRITE_CIPHER, NULL, 0); 236 237 printf("%s: rc=%d\n", __FUNCTION__, rc); 238 239 return rc; 240} 241 242static 243int TLSSocket_AdvanceReadCipher(SSLRecordContextRef ref) 244{ 245 int socket = (int)ref; 246 int rc; 247 rc = setsockopt(socket, SOL_SOCKET, SO_TLS_ADVANCE_READ_CIPHER, NULL, 0); 248 249 printf("%s: rc=%d\n", __FUNCTION__, rc); 250 251 return rc; 252} 253 254static 255int TLSSocket_SetProtocolVersion(SSLRecordContextRef ref, 256 SSLProtocolVersion protocolVersion) 257{ 258 int socket = (int)ref; 259 int rc; 260 rc = setsockopt(socket, SOL_SOCKET, SO_TLS_PROTOCOL_VERSION, &protocolVersion, sizeof(protocolVersion)); 261 262 printf("%s: rc=%d\n", __FUNCTION__, rc); 263 264 return rc; 265} 266 267 268static 269int TLSSocket_ServiceWriteQueue(SSLRecordContextRef ref) 270{ 271 int socket = (int)ref; 272 int rc; 273 rc = setsockopt(socket, SOL_SOCKET, SO_TLS_SERVICE_WRITE_QUEUE, NULL, 0); 274 275 return rc; 276} 277 278 279const struct SSLRecordFuncs TLSSocket_Funcs = { 280 .read = TLSSocket_Read, 281 .write = TLSSocket_Write, 282 .initPendingCiphers = TLSSocket_InitPendingCiphers, 283 .advanceWriteCipher = TLSSocket_AdvanceWriteCipher, 284 .rollbackWriteCipher = TLSSocket_RollbackWriteCipher, 285 .advanceReadCipher = TLSSocket_AdvanceReadCipher, 286 .setProtocolVersion = TLSSocket_SetProtocolVersion, 287 .free = TLSSocket_Free, 288 .serviceWriteQueue = TLSSocket_ServiceWriteQueue, 289}; 290 291 292/* TLSSocket SPIs */ 293 294int TLSSocket_Attach(int socket) 295{ 296 297 /* Attach the TLS socket filter and return handle */ 298 struct so_nke so_tlsnke; 299 int rc; 300 int handle; 301 socklen_t len; 302 303 memset(&so_tlsnke, 0, sizeof(so_tlsnke)); 304 so_tlsnke.nke_handle = TLS_HANDLE_IP4; 305 rc=setsockopt(socket, SOL_SOCKET, SO_NKE, &so_tlsnke, sizeof(so_tlsnke)); 306 if(rc) 307 return rc; 308 309 len = sizeof(handle); 310 rc = getsockopt(socket, SOL_SOCKET, SO_TLS_HANDLE, &handle, &len); 311 if(rc) 312 return rc; 313 314 assert(len==sizeof(handle)); 315 316 return handle; 317} 318 319