1/*
2 * Copyright (c) 2011-2014 Apple Inc. All Rights Reserved.
3 *
4 * @APPLE_LICENSE_HEADER_START@
5 *
6 * This file contains Original Code and/or Modifications of Original Code
7 * as defined in and that are subject to the Apple Public Source License
8 * Version 2.0 (the 'License'). You may not use this file except in
9 * compliance with the License. Please obtain a copy of the License at
10 * http://www.opensource.apple.com/apsl/ and read it before using this
11 * file.
12 *
13 * The Original Code and all software distributed under the License are
14 * distributed on an 'AS IS' basis, WITHOUT WARRANTY OF ANY KIND, EITHER
15 * EXPRESS OR IMPLIED, AND APPLE HEREBY DISCLAIMS ALL SUCH WARRANTIES,
16 * INCLUDING WITHOUT LIMITATION, ANY WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE, QUIET ENJOYMENT OR NON-INFRINGEMENT.
18 * Please see the License for the specific language governing rights and
19 * limitations under the License.
20 *
21 * @APPLE_LICENSE_HEADER_END@
22 */
23
24
25/* THIS FILE CONTAINS KERNEL CODE */
26
27#include "sslBuildFlags.h"
28#include "SSLRecordInternal.h"
29#include "sslDebug.h"
30#include "cipherSpecs.h"
31#include "symCipher.h"
32#include "sslUtils.h"
33#include "tls_record_internal.h"
34
35#include <AssertMacros.h>
36#include <string.h>
37
38#include <inttypes.h>
39#include <stddef.h>
40
/*
 * Largest TLSCiphertext.fragment allowed on the wire: 2^14 + 2048 bytes
 * (TLS 1.2, RFC 5246 section 6.2.3).
 */
#define RECORD_INTERNAL_MAX_FRAGMENT_SIZE (16384 + 2048)

/*
 * Largest record header this layer may buffer in front of the fragment:
 * 5 bytes for TLS, 13 bytes for DTLS (RFC 6347 section 4.1).
 */
#define RECORD_INTERNAL_MAX_HEADER_SIZE 13

/*
 * Size of partialReadBuffer. The read path stores the header AND the
 * fragment contiguously (it requires head + contentLen <= buffer length),
 * so the buffer must hold both; sizing it for the fragment alone would
 * reject a spec-maximum record with errSSLRecordRecordOverflow.
 */
#define DEFAULT_BUFFER_SIZE (RECORD_INTERNAL_MAX_HEADER_SIZE + RECORD_INTERNAL_MAX_FRAGMENT_SIZE)
43
44
45/*
46 * Redirect SSLBuffer-based I/O call to user-supplied I/O.
47 */
48static
49int sslIoRead(SSLBuffer                        buf,
50              size_t                           *actualLength,
51              struct SSLRecordInternalContext  *ctx)
52{
53	size_t  dataLength = buf.length;
54	int     ortn;
55
56	*actualLength = 0;
57	ortn = (ctx->read)(ctx->ioRef,
58                       buf.data,
59                       &dataLength);
60	*actualLength = dataLength;
61
62    sslLogRecordIo("sslIoRead: [%p] req %4lu actual %4lu status %d",
63                   ctx, buf.length, dataLength, (int)ortn);
64
65	return ortn;
66}
67
68static
69int sslIoWrite(SSLBuffer                       buf,
70               size_t                          *actualLength,
71               struct SSLRecordInternalContext *ctx)
72{
73	size_t  dataLength = buf.length;
74	int     ortn;
75
76	*actualLength = 0;
77	ortn = (ctx->write)(ctx->ioRef,
78                        buf.data,
79                        &dataLength);
80	*actualLength = dataLength;
81
82    sslLogRecordIo("sslIoWrite: [%p] req %4lu actual %4lu status %d",
83                   ctx, buf.length, dataLength, (int)ortn);
84
85	return ortn;
86}
87
88/* Entry points to Record Layer */
89
90static int SSLRecordReadInternal(SSLRecordContextRef ref, SSLRecord *rec)
91{
92    struct SSLRecordInternalContext *ctx = ref;
93
94    int     err;
95    size_t  len, contentLen;
96    SSLBuffer readData;
97
98    size_t head=tls_record_get_header_size(ctx->filter);
99
100    if (ctx->amountRead < head)
101    {
102        readData.length = head - ctx->amountRead;
103        readData.data = ctx->partialReadBuffer.data + ctx->amountRead;
104        len = readData.length;
105        err = sslIoRead(readData, &len, ctx);
106        if(err != 0)
107        {
108            switch(err) {
109                case errSSLRecordWouldBlock:
110                    ctx->amountRead += len;
111                    break;
112                default:
113                    /* Any other error but errSSLWouldBlock is  translated to errSSLRecordClosedAbort */
114                    err = errSSLRecordClosedAbort;
115                    break;
116            }
117            return err;
118        }
119        ctx->amountRead += len;
120
121        check(ctx->amountRead == head);
122    }
123
124
125    tls_buffer header;
126    header.data=ctx->partialReadBuffer.data;
127    header.length=head;
128
129    uint8_t content_type;
130
131    tls_record_parse_header(ctx->filter, header, &contentLen, &content_type);
132
133    if(content_type&0x80) {
134        // Looks like SSL2 record, reset expectations.
135        head = 2;
136        err=tls_record_parse_ssl2_header(ctx->filter, header, &contentLen, &content_type);
137        if(err!=0) return errSSLRecordUnexpectedRecord;
138    }
139
140    check(ctx->partialReadBuffer.length>=head+contentLen);
141
142    if(head+contentLen>ctx->partialReadBuffer.length)
143        return errSSLRecordRecordOverflow;
144
145    if (ctx->amountRead < head + contentLen)
146    {   readData.length = head + contentLen - ctx->amountRead;
147        readData.data = ctx->partialReadBuffer.data + ctx->amountRead;
148        len = readData.length;
149        err = sslIoRead(readData, &len, ctx);
150        if(err != 0)
151        {   if (err == errSSLRecordWouldBlock)
152            ctx->amountRead += len;
153            return err;
154        }
155        ctx->amountRead += len;
156    }
157
158    check(ctx->amountRead == head + contentLen);
159
160    tls_buffer record;
161    record.data = ctx->partialReadBuffer.data;
162    record.length = ctx->amountRead;
163
164    rec->contentType = content_type;
165
166    ctx->amountRead = 0;        /* We've used all the data in the cache */
167
168    if(content_type==tls_record_type_SSL2) {
169        /* Just copy the SSL2 record, dont decrypt since this is only for SSL2 Client Hello */
170        return SSLCopyBuffer(&record, &rec->contents);
171    } else {
172        size_t sz = tls_record_decrypted_size(ctx->filter, record.length);
173
174        /* There was an underflow - For TLS, we return errSSLRecordClosedAbort for historical reason - see ssl-44-crashes test */
175        if(sz==0) {
176            sslErrorLog("underflow in SSLReadRecordInternal");
177            if(ctx->dtls) {
178                // For DTLS, we should just drop it.
179                return errSSLRecordUnexpectedRecord;
180            } else {
181                // For TLS, we are going to close the connection.
182                return errSSLRecordClosedAbort;
183            }
184        }
185
186        /* Allocate a buffer for the plaintext */
187        if ((err = SSLAllocBuffer(&rec->contents, sz)))
188        {
189            return err;
190        }
191
192        return tls_record_decrypt(ctx->filter, record, &rec->contents, NULL);
193    }
194}
195
196static int SSLRecordWriteInternal(SSLRecordContextRef ref, SSLRecord rec)
197{
198    int err;
199    struct SSLRecordInternalContext *ctx = ref;
200    WaitingRecord *queue, *out;
201    tls_buffer data;
202    tls_buffer content;
203    size_t len;
204
205    err = errSSLRecordInternal; /* FIXME: allocation error */
206    len=tls_record_encrypted_size(ctx->filter, rec.contentType, rec.contents.length);
207
208    require((out = (WaitingRecord *)sslMalloc(offsetof(WaitingRecord, data) + len)), fail);
209    out->next = NULL;
210	out->sent = 0;
211	out->length = len;
212
213    data.data=&out->data[0];
214    data.length=out->length;
215
216    content.data = rec.contents.data;
217    content.length = rec.contents.length;
218
219    require_noerr((err=tls_record_encrypt(ctx->filter, content, rec.contentType, &data)), fail);
220
221    out->length = data.length; // This should not be needed if tls_record_encrypted_size works properly.
222
223    /* Enqueue the record to be written from the idle loop */
224    if (ctx->recordWriteQueue == 0)
225        ctx->recordWriteQueue = out;
226    else
227    {   queue = ctx->recordWriteQueue;
228        while (queue->next != 0)
229            queue = queue->next;
230        queue->next = out;
231    }
232
233    return 0;
234fail:
235    if(out)
236        sslFree(out);
237    return err;
238}
239
240/* Record Layer Entry Points */
241
242static int
243SSLRollbackInternalRecordLayerWriteCipher(SSLRecordContextRef ref)
244{
245    struct SSLRecordInternalContext *ctx = ref;
246    return tls_record_rollback_write_cipher(ctx->filter);
247}
248
249static int
250SSLAdvanceInternalRecordLayerWriteCipher(SSLRecordContextRef ref)
251{
252    struct SSLRecordInternalContext *ctx = ref;
253    return tls_record_advance_write_cipher(ctx->filter);
254}
255
256static int
257SSLAdvanceInternalRecordLayerReadCipher(SSLRecordContextRef ref)
258{
259    struct SSLRecordInternalContext *ctx = ref;
260    return tls_record_advance_read_cipher(ctx->filter);
261}
262
263static int
264SSLInitInternalRecordLayerPendingCiphers(SSLRecordContextRef ref, uint16_t selectedCipher, bool isServer, SSLBuffer key)
265{
266    struct SSLRecordInternalContext *ctx = ref;
267    return tls_record_init_pending_ciphers(ctx->filter, selectedCipher, isServer, key);
268}
269
270static int
271SSLSetInternalRecordLayerProtocolVersion(SSLRecordContextRef ref, SSLProtocolVersion negVersion)
272{
273    struct SSLRecordInternalContext *ctx = ref;
274    return tls_record_set_protocol_version(ctx->filter, negVersion);
275}
276
277static int
278SSLRecordFreeInternal(SSLRecordContextRef ref, SSLRecord rec)
279{
280    return SSLFreeBuffer(&rec.contents);
281}
282
283static int
284SSLRecordServiceWriteQueueInternal(SSLRecordContextRef ref)
285{
286    int             err = 0, werr = 0;
287    size_t          written = 0;
288    SSLBuffer       buf;
289    WaitingRecord   *rec;
290    struct SSLRecordInternalContext *ctx= ref;
291
292    while (!werr && ((rec = ctx->recordWriteQueue) != 0))
293    {   buf.data = rec->data + rec->sent;
294        buf.length = rec->length - rec->sent;
295        werr = sslIoWrite(buf, &written, ctx);
296        rec->sent += written;
297        if (rec->sent >= rec->length)
298        {
299            check(rec->sent == rec->length);
300            check(err == 0);
301            ctx->recordWriteQueue = rec->next;
302			sslFree(rec);
303        }
304        if (err) {
305            check_noerr(err);
306            return err;
307        }
308    }
309
310    return werr;
311}
312
313static int
314SSLRecordSetOption(SSLRecordContextRef ref, SSLRecordOption option, bool value)
315{
316    struct SSLRecordInternalContext *ctx = (struct SSLRecordInternalContext *)ref;
317    switch (option) {
318        case kSSLRecordOptionSendOneByteRecord:
319            return tls_record_set_record_splitting(ctx->filter, value);
320            break;
321        default:
322            return 0;
323            break;
324    }
325}
326
327/***** Internal Record Layer APIs *****/
328
#include <CommonCrypto/CommonRandomSPI.h>
/*
 * RNG handle passed to tls_record_create() in SSLCreateInternalRecordLayer().
 * ccDRBGGetRngState() is presumably declared by the CommonRandomSPI.h
 * include above -- TODO confirm.
 */
#define CCRNGSTATE ccDRBGGetRngState()
331
332SSLRecordContextRef
333SSLCreateInternalRecordLayer(bool dtls)
334{
335    struct SSLRecordInternalContext *ctx;
336
337    ctx = sslMalloc(sizeof(struct SSLRecordInternalContext));
338    if(ctx==NULL)
339        return NULL;
340
341    memset(ctx, 0, sizeof(struct SSLRecordInternalContext));
342
343    ctx->dtls = dtls;
344    require((ctx->filter=tls_record_create(dtls, CCRNGSTATE)), fail);
345    require_noerr(SSLAllocBuffer(&ctx->partialReadBuffer,
346                                 DEFAULT_BUFFER_SIZE), fail);
347
348    return ctx;
349
350fail:
351    if(ctx->filter)
352        tls_record_destroy(ctx->filter);
353    sslFree(ctx);
354    return NULL;
355}
356
357int
358SSLSetInternalRecordLayerIOFuncs(
359                                 SSLRecordContextRef ref,
360                                 SSLIOReadFunc    readFunc,
361                                 SSLIOWriteFunc   writeFunc)
362{
363    struct SSLRecordInternalContext *ctx = ref;
364
365    ctx->read = readFunc;
366    ctx->write = writeFunc;
367
368    return 0;
369}
370
371int
372SSLSetInternalRecordLayerConnection(
373                                    SSLRecordContextRef ref,
374                                    SSLIOConnectionRef ioRef)
375{
376    struct SSLRecordInternalContext *ctx = ref;
377
378    ctx->ioRef = ioRef;
379
380    return 0;
381}
382
383void
384SSLDestroyInternalRecordLayer(SSLRecordContextRef ref)
385{
386    struct SSLRecordInternalContext *ctx = ref;
387	WaitingRecord   *waitRecord, *next;
388
389    /* RecordContext cleanup : */
390    SSLFreeBuffer(&ctx->partialReadBuffer);
391    waitRecord = ctx->recordWriteQueue;
392    while (waitRecord)
393    {   next = waitRecord->next;
394        sslFree(waitRecord);
395        waitRecord = next;
396    }
397
398    if(ctx->filter)
399        tls_record_destroy(ctx->filter);
400
401    sslFree(ctx);
402
403}
404
405struct SSLRecordFuncs SSLRecordLayerInternal =
406{
407    .read  = SSLRecordReadInternal,
408    .write = SSLRecordWriteInternal,
409    .initPendingCiphers = SSLInitInternalRecordLayerPendingCiphers,
410    .advanceWriteCipher = SSLAdvanceInternalRecordLayerWriteCipher,
411    .advanceReadCipher = SSLAdvanceInternalRecordLayerReadCipher,
412    .rollbackWriteCipher = SSLRollbackInternalRecordLayerWriteCipher,
413    .setProtocolVersion = SSLSetInternalRecordLayerProtocolVersion,
414    .free = SSLRecordFreeInternal,
415    .serviceWriteQueue = SSLRecordServiceWriteQueueInternal,
416    .setOption = SSLRecordSetOption,
417};
418
419