1/*	$NetBSD: cipher-ctr-mt.c,v 1.10 2019/01/27 02:08:33 pgoyette Exp $	*/
2/*
3 * OpenSSH Multi-threaded AES-CTR Cipher
4 *
5 * Author: Benjamin Bennett <ben@psc.edu>
6 * Copyright (c) 2008 Pittsburgh Supercomputing Center. All rights reserved.
7 *
8 * Based on original OpenSSH AES-CTR cipher. Small portions remain unchanged,
9 * Copyright (c) 2003 Markus Friedl <markus@openbsd.org>
10 *
11 * Permission to use, copy, modify, and distribute this software for any
12 * purpose with or without fee is hereby granted, provided that the above
13 * copyright notice and this permission notice appear in all copies.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
16 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
17 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
18 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
19 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
20 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
21 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
22 */
23#include "includes.h"
24__RCSID("$NetBSD: cipher-ctr-mt.c,v 1.10 2019/01/27 02:08:33 pgoyette Exp $");
25
26#include <sys/types.h>
27
28#include <stdarg.h>
29#include <string.h>
30
31#include <openssl/evp.h>
32
33#include "xmalloc.h"
34#include "log.h"
35
36#ifndef USE_BUILTIN_RIJNDAEL
37#include <openssl/aes.h>
38#endif
39
40#include <pthread.h>
41
42/*-------------------- TUNABLES --------------------*/
43/* Number of pregen threads to use */
44#define CIPHER_THREADS	2
45
46/* Number of keystream queues */
47#define NUMKQ		(CIPHER_THREADS + 2)
48
49/* Length of a keystream queue */
50#define KQLEN		4096
51
52/* Processor cacheline length */
53#define CACHELINE_LEN	64
54
55/* Collect thread stats and print at cancellation when in debug mode */
56/* #define CIPHER_THREAD_STATS */
57
58/* Use single-byte XOR instead of 8-byte XOR */
59/* #define CIPHER_BYTE_XOR */
60/*-------------------- END TUNABLES --------------------*/
61
62#ifdef AES_CTR_MT
63
64
65const EVP_CIPHER *evp_aes_ctr_mt(void);
66
#ifdef CIPHER_THREAD_STATS
/*
 * Struct to collect thread stats
 */
struct thread_stats {
	u_int	fills;	/* keystream queues this thread filled */
	u_int	skips;	/* queues skipped because another thread claimed them */
	u_int	waits;	/* blocking waits on a draining/initializing queue */
	u_int	drains;	/* queues drained; bumped only on the consumer's stats */
};
77
78/*
79 * Debug print the thread stats
80 * Use with pthread_cleanup_push for displaying at thread cancellation
81 */
82static void
83thread_loop_stats(void *x)
84{
85	struct thread_stats *s = x;
86
87	debug("tid %lu - %u fills, %u skips, %u waits", pthread_self(),
88			s->fills, s->skips, s->waits);
89}
90
 /* Stat helpers: expand to real bookkeeping only when stats are enabled. */
 #define STATS_STRUCT(s)	struct thread_stats s
 #define STATS_INIT(s)		{ memset(&s, 0, sizeof(s)); }
 #define STATS_FILL(s)		{ s.fills++; }
 #define STATS_SKIP(s)		{ s.skips++; }
 #define STATS_WAIT(s)		{ s.waits++; }
 #define STATS_DRAIN(s)		{ s.drains++; }
#else
 /* Stats disabled: all helpers compile away to nothing. */
 #define STATS_STRUCT(s)
 #define STATS_INIT(s)
 #define STATS_FILL(s)
 #define STATS_SKIP(s)
 #define STATS_WAIT(s)
 #define STATS_DRAIN(s)
#endif
105
/* Keystream Queue state */
enum {
	KQINIT,		/* startup only: q[0] waiting for thread 0's first fill */
	KQEMPTY,	/* consumed; any pregen thread may claim and refill it */
	KQFILLING,	/* a pregen thread is generating keystream (lock released) */
	KQFULL,		/* filled; ready for the consumer */
	KQDRAINING	/* consumer is reading keystream blocks out of it */
};
114
/* Keystream Queue struct */
struct kq {
	u_char		keys[KQLEN][AES_BLOCK_SIZE];	/* pregenerated keystream blocks */
	u_char		ctr[AES_BLOCK_SIZE];	/* counter for this queue's NEXT fill */
	u_char		pad0[CACHELINE_LEN];	/* keep lock state off the data cachelines */
	volatile int	qstate;		/* KQ* state (see enum in this file) */
	pthread_mutex_t	lock;		/* protects qstate transitions */
	pthread_cond_t	cond;		/* signaled on qstate transitions */
	u_char		pad1[CACHELINE_LEN];	/* isolate adjacent queues (false sharing) */
};
125
/* Context struct */
struct ssh_aes_ctr_ctx
{
	struct kq	q[NUMKQ];	/* ring of keystream queues */
	AES_KEY		aes_ctx;	/* expanded AES key; each thread takes a copy */
	STATS_STRUCT(stats);		/* consumer-side stats (when enabled) */
	u_char		aes_counter[AES_BLOCK_SIZE];	/* NOTE(review): not referenced in this file — ctx->iv is used instead */
	pthread_t	tid[CIPHER_THREADS];	/* pregen thread ids */
	int		state;		/* HAVE_KEY | HAVE_IV progress flags */
	int		qidx;		/* queue currently being drained */
	int		ridx;		/* next keystream block index within q[qidx] */
};
138
139/* <friedl>
140 * increment counter 'ctr',
141 * the counter is of size 'len' bytes and stored in network-byte-order.
142 * (LSB at ctr[len-1], MSB at ctr[0])
143 */
144static void
145ssh_ctr_inc(u_char *ctr, u_int len)
146{
147	int i;
148
149	for (i = len - 1; i >= 0; i--)
150		if (++ctr[i])	/* continue on overflow */
151			return;
152}
153
154/*
155 * Add num to counter 'ctr'
156 */
157static void
158ssh_ctr_add(u_char *ctr, uint32_t num, u_int len)
159{
160	int i;
161	uint16_t n;
162
163	for (n = 0, i = len - 1; i >= 0 && (num || n); i--) {
164		n = ctr[i] + (num & 0xff) + n;
165		num >>= 8;
166		ctr[i] = n & 0xff;
167		n >>= 8;
168	}
169}
170
/*
 * Cancellation cleanup handler: a thread cancelled inside
 * pthread_cond_wait() wakes holding the mutex, so release it here.
 */
static void
thread_loop_cleanup(void *x)
{
	pthread_mutex_t *lock = x;

	pthread_mutex_unlock(lock);
}
179
/*
 * The life of a pregen thread:
 *    Find empty keystream queues and fill them using their counter.
 *    When done, update counter for the next fill.
 * Runs until cancelled from ssh_aes_ctr_init() (re-key) or
 * ssh_aes_ctr_cleanup().
 */
static void *
thread_loop(void *x)
{
	AES_KEY key;
	STATS_STRUCT(stats);
	struct ssh_aes_ctr_ctx *c = x;
	struct kq *q;
	int i;
	int qidx;

	/* Threads stats on cancellation */
	STATS_INIT(stats);
#ifdef CIPHER_THREAD_STATS
	pthread_cleanup_push(thread_loop_stats, &stats);
#endif

	/* Thread local copy of AES key */
	memcpy(&key, &c->aes_ctx, sizeof(key));

	/*
	 * Handle the special case of startup, one thread must fill
	 * the first KQ then mark it as draining. Lock held throughout.
	 */
	if (pthread_equal(pthread_self(), c->tid[0])) {
		q = &c->q[0];
		pthread_mutex_lock(&q->lock);
		if (q->qstate == KQINIT) {
			for (i = 0; i < KQLEN; i++) {
				AES_encrypt(q->ctr, q->keys[i], &key);
				ssh_ctr_inc(q->ctr, AES_BLOCK_SIZE);
			}
			/* Skip over the blocks the other NUMKQ-1 queues cover. */
			ssh_ctr_add(q->ctr, KQLEN * (NUMKQ - 1), AES_BLOCK_SIZE);
			q->qstate = KQDRAINING;
			STATS_FILL(stats);
			/* Wakes the consumer blocked in ssh_aes_ctr_init(). */
			pthread_cond_broadcast(&q->cond);
		}
		pthread_mutex_unlock(&q->lock);
	}
	else
		STATS_SKIP(stats);

	/*
	 * Normal case is to find empty queues and fill them, skipping over
	 * queues already filled by other threads and stopping to wait for
	 * a draining queue to become empty.
	 *
	 * Multiple threads may be waiting on a draining queue and awoken
	 * when empty.  The first thread to wake will mark it as filling,
	 * others will move on to fill, skip, or wait on the next queue.
	 */
	for (qidx = 1;; qidx = (qidx + 1) % NUMKQ) {
		/* Check if I was cancelled, also checked in cond_wait */
		pthread_testcancel();

		/* Lock queue and block if its draining */
		q = &c->q[qidx];
		pthread_mutex_lock(&q->lock);
		/* Cleanup handler releases the lock if cancelled in cond_wait. */
		pthread_cleanup_push(thread_loop_cleanup, &q->lock);
		while (q->qstate == KQDRAINING || q->qstate == KQINIT) {
			STATS_WAIT(stats);
			pthread_cond_wait(&q->cond, &q->lock);
		}
		pthread_cleanup_pop(0);

		/* If filling or full, somebody else got it, skip */
		if (q->qstate != KQEMPTY) {
			pthread_mutex_unlock(&q->lock);
			STATS_SKIP(stats);
			continue;
		}

		/*
		 * Empty, let's fill it.
		 * Queue lock is relinquished while we do this so others
		 * can see that it's being filled.
		 */
		q->qstate = KQFILLING;
		pthread_mutex_unlock(&q->lock);
		for (i = 0; i < KQLEN; i++) {
			AES_encrypt(q->ctr, q->keys[i], &key);
			ssh_ctr_inc(q->ctr, AES_BLOCK_SIZE);
		}

		/* Re-lock, mark full and signal consumer */
		pthread_mutex_lock(&q->lock);
		/* Pre-position the counter for this queue's NEXT fill. */
		ssh_ctr_add(q->ctr, KQLEN * (NUMKQ - 1), AES_BLOCK_SIZE);
		q->qstate = KQFULL;
		STATS_FILL(stats);
		pthread_cond_signal(&q->cond);
		pthread_mutex_unlock(&q->lock);
	}

#ifdef CIPHER_THREAD_STATS
	/* Stats */
	/* NOTE: not reached — the loop above never breaks; threads exit
	 * via pthread_cancel(), which runs the pushed cleanup handler. */
	pthread_cleanup_pop(1);
#endif

	return NULL;
}
284
285static int
286ssh_aes_ctr(EVP_CIPHER_CTX *ctx, u_char *dest, const u_char *src,
287    u_int len)
288{
289	struct ssh_aes_ctr_ctx *c;
290	struct kq *q, *oldq;
291	int ridx;
292	u_char *buf;
293
294	if (len == 0)
295		return (1);
296	if ((c = EVP_CIPHER_CTX_get_app_data(ctx)) == NULL)
297		return (0);
298
299	q = &c->q[c->qidx];
300	ridx = c->ridx;
301
302	/* src already padded to block multiple */
303	while (len > 0) {
304		buf = q->keys[ridx];
305
306#ifdef CIPHER_BYTE_XOR
307		dest[0] = src[0] ^ buf[0];
308		dest[1] = src[1] ^ buf[1];
309		dest[2] = src[2] ^ buf[2];
310		dest[3] = src[3] ^ buf[3];
311		dest[4] = src[4] ^ buf[4];
312		dest[5] = src[5] ^ buf[5];
313		dest[6] = src[6] ^ buf[6];
314		dest[7] = src[7] ^ buf[7];
315		dest[8] = src[8] ^ buf[8];
316		dest[9] = src[9] ^ buf[9];
317		dest[10] = src[10] ^ buf[10];
318		dest[11] = src[11] ^ buf[11];
319		dest[12] = src[12] ^ buf[12];
320		dest[13] = src[13] ^ buf[13];
321		dest[14] = src[14] ^ buf[14];
322		dest[15] = src[15] ^ buf[15];
323#else
324		*(uint64_t *)dest = *(uint64_t *)src ^ *(uint64_t *)buf;
325		*(uint64_t *)(dest + 8) = *(uint64_t *)(src + 8) ^
326						*(uint64_t *)(buf + 8);
327#endif
328
329		dest += 16;
330		src += 16;
331		len -= 16;
332		ssh_ctr_inc(ctx->iv, AES_BLOCK_SIZE);
333
334		/* Increment read index, switch queues on rollover */
335		if ((ridx = (ridx + 1) % KQLEN) == 0) {
336			oldq = q;
337
338			/* Mark next queue draining, may need to wait */
339			c->qidx = (c->qidx + 1) % NUMKQ;
340			q = &c->q[c->qidx];
341			pthread_mutex_lock(&q->lock);
342			while (q->qstate != KQFULL) {
343				STATS_WAIT(c->stats);
344				pthread_cond_wait(&q->cond, &q->lock);
345			}
346			q->qstate = KQDRAINING;
347			pthread_mutex_unlock(&q->lock);
348
349			/* Mark consumed queue empty and signal producers */
350			pthread_mutex_lock(&oldq->lock);
351			oldq->qstate = KQEMPTY;
352			STATS_DRAIN(c->stats);
353			pthread_cond_broadcast(&oldq->cond);
354			pthread_mutex_unlock(&oldq->lock);
355		}
356	}
357	c->ridx = ridx;
358	return (1);
359}
360
#define HAVE_NONE       0
#define HAVE_KEY        1
#define HAVE_IV         2
/*
 * EVP init callback.  May be called multiple times (key and IV can be
 * supplied in separate calls); the pregen threads are (re)started only
 * once both have been seen.  'enc' is ignored — CTR-mode encryption and
 * decryption are the same operation.  Always returns 1.
 */
static int
ssh_aes_ctr_init(EVP_CIPHER_CTX *ctx, const u_char *key, const u_char *iv,
    int enc)
{
	struct ssh_aes_ctr_ctx *c;
	int i;

	/* First call on this ctx: allocate context, init locks/conds. */
	if ((c = EVP_CIPHER_CTX_get_app_data(ctx)) == NULL) {
		c = xmalloc(sizeof(*c));

		c->state = HAVE_NONE;
		for (i = 0; i < NUMKQ; i++) {
			pthread_mutex_init(&c->q[i].lock, NULL);
			pthread_cond_init(&c->q[i].cond, NULL);
		}

		STATS_INIT(c->stats);

		EVP_CIPHER_CTX_set_app_data(ctx, c);
	}

	/* Re-keying an active context: stop the old pregen threads first. */
	if (c->state == (HAVE_KEY | HAVE_IV)) {
		/* Cancel pregen threads */
		for (i = 0; i < CIPHER_THREADS; i++)
			pthread_cancel(c->tid[i]);
		for (i = 0; i < CIPHER_THREADS; i++)
			pthread_join(c->tid[i], NULL);
		/* Start over getting key & iv */
		c->state = HAVE_NONE;
	}

	if (key != NULL) {
		AES_set_encrypt_key(key, EVP_CIPHER_CTX_key_length(ctx) * 8,
		    &c->aes_ctx);
		c->state |= HAVE_KEY;
	}

	if (iv != NULL) {
		memcpy(ctx->iv, iv, AES_BLOCK_SIZE);
		c->state |= HAVE_IV;
	}

	/* Both pieces present: reset the queues and launch the threads. */
	if (c->state == (HAVE_KEY | HAVE_IV)) {
		/* Clear queues: queue i starts i*KQLEN blocks into the stream. */
		memcpy(c->q[0].ctr, ctx->iv, AES_BLOCK_SIZE);
		c->q[0].qstate = KQINIT;
		for (i = 1; i < NUMKQ; i++) {
			memcpy(c->q[i].ctr, ctx->iv, AES_BLOCK_SIZE);
			ssh_ctr_add(c->q[i].ctr, i * KQLEN, AES_BLOCK_SIZE);
			c->q[i].qstate = KQEMPTY;
		}
		c->qidx = 0;
		c->ridx = 0;

		/* Start threads */
		for (i = 0; i < CIPHER_THREADS; i++) {
			pthread_create(&c->tid[i], NULL, thread_loop, c);
		}
		/* Block until thread 0 fills q[0] (KQINIT -> KQDRAINING). */
		pthread_mutex_lock(&c->q[0].lock);
		while (c->q[0].qstate != KQDRAINING)
			pthread_cond_wait(&c->q[0].cond, &c->q[0].lock);
		pthread_mutex_unlock(&c->q[0].lock);

	}
	return (1);
}
430
/*
 * EVP cleanup callback: cancel and join the pregen threads, then scrub
 * and free the context.  Always returns 1.
 */
static int
ssh_aes_ctr_cleanup(EVP_CIPHER_CTX *ctx)
{
	struct ssh_aes_ctr_ctx *c;
	int i;

	if ((c = EVP_CIPHER_CTX_get_app_data(ctx)) != NULL) {
#ifdef CIPHER_THREAD_STATS
		debug("main thread: %u drains, %u waits", c->stats.drains,
				c->stats.waits);
#endif
		/* Cancel pregen threads */
		for (i = 0; i < CIPHER_THREADS; i++)
			pthread_cancel(c->tid[i]);
		for (i = 0; i < CIPHER_THREADS; i++)
			pthread_join(c->tid[i], NULL);

		/*
		 * Scrub key material and keystream before freeing.
		 * NOTE(review): a plain memset() immediately before
		 * free() may be elided by the optimizer; consider
		 * explicit_bzero() here.
		 */
		memset(c, 0, sizeof(*c));
		free(c);
		EVP_CIPHER_CTX_set_app_data(ctx, NULL);
	}
	return (1);
}
454
455/* <friedl> */
456const EVP_CIPHER *
457evp_aes_ctr_mt(void)
458{
459	static EVP_CIPHER aes_ctr;
460
461	memset(&aes_ctr, 0, sizeof(EVP_CIPHER));
462	aes_ctr.nid = NID_undef;
463	aes_ctr.block_size = AES_BLOCK_SIZE;
464	aes_ctr.iv_len = AES_BLOCK_SIZE;
465	aes_ctr.key_len = 16;
466	aes_ctr.init = ssh_aes_ctr_init;
467	aes_ctr.cleanup = ssh_aes_ctr_cleanup;
468	aes_ctr.do_cipher = ssh_aes_ctr;
469#ifndef SSH_OLD_EVP
470	aes_ctr.flags = EVP_CIPH_CBC_MODE | EVP_CIPH_VARIABLE_LENGTH |
471	    EVP_CIPH_ALWAYS_CALL_INIT | EVP_CIPH_CUSTOM_IV;
472#endif
473	return (&aes_ctr);
474}
475#endif /* AES_CTR_MT */
476