1/*	$OpenBSD: aeadtest.c,v 1.26 2023/09/28 14:55:48 tb Exp $	*/
2/*
3 * Copyright (c) 2022 Joel Sing <jsing@openbsd.org>
4 * Copyright (c) 2014, Google Inc.
5 *
6 * Permission to use, copy, modify, and/or distribute this software for any
7 * purpose with or without fee is hereby granted, provided that the above
8 * copyright notice and this permission notice appear in all copies.
9 *
10 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
11 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
12 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
13 * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
14 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
15 * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
16 * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
17 */
18
19#include <ctype.h>
20#include <stdint.h>
21#include <stdio.h>
22#include <stdlib.h>
23#include <string.h>
24#include <unistd.h>
25
26#include <openssl/err.h>
27#include <openssl/evp.h>
28
29/*
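 * This program tests an AEAD against a series of test vectors from a file. The
 * test vector file consists of key-value lines where the key and value are
 * separated by a colon and optional whitespace. The keys are listed in
 * NAMES, below. The values are hex-encoded data, except that the AEAD value
 * is an algorithm name, and any other value may instead be given as a
 * double-quoted literal string. Lines starting with '#' are comments. If a
 * test case has no AEAD line, the AEAD named on the command line is used.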
34 *
35 * After a number of key-value lines, a blank line indicates the end of the
36 * test case.
37 *
38 * For example, here's a valid test case:
39 *
40 *   AEAD: chacha20-poly1305
41 *   KEY: bcb2639bf989c6251b29bf38d39a9bdce7c55f4b2ac12a39c8a37b5d0a5cc2b5
42 *   NONCE: 1e8b4c510f5ca083
43 *   IN: 8c8419bc27
44 *   AD: 34ab88c265
45 *   CT: 1a7c2f33f5
46 *   TAG: 2875c659d0f2808de3a40027feff91a4
47 */
48
/* Capacity, in bytes, of the buffer holding each field of a test case. */
#define BUF_MAX 1024

/* MSVC's global headers define IN as a macro; drop it so the enum compiles. */
#if defined(_MSC_VER) && defined(IN)
#undef IN
#endif

/*
 * Field types that may appear in a test case.  Each value doubles as the
 * index of that field's buffer and length in the parser's arrays.
 */
enum {
	AEAD = 0,	/* name of the AEAD algorithm */
	KEY = 1,	/* key */
	NONCE = 2,	/* nonce */
	IN = 3,		/* plaintext */
	AD = 4,		/* additional data */
	CT = 5,		/* ciphertext, excluding the authenticator (TAG) */
	TAG = 6,	/* authenticator */
	NUM_TYPES = 7	/* number of field types */
};

/* Line prefix ("TYPE:") that introduces each field type. */
static const char NAMES[NUM_TYPES][6] = {
	[AEAD] = "AEAD",
	[KEY] = "KEY",
	[NONCE] = "NONCE",
	[IN] = "IN",
	[AD] = "AD",
	[CT] = "CT",
	[TAG] = "TAG",
};
80
/*
 * Convert one hexadecimal character (either case) to its value 0..15.
 * Any other character yields 16, which callers treat as "not hex".
 */
static unsigned char
hex_digit(char h)
{
	if ('0' <= h && h <= '9')
		return h - '0';
	if ('a' <= h && h <= 'f')
		return 10 + (h - 'a');
	if ('A' <= h && h <= 'F')
		return 10 + (h - 'A');

	return 16;
}
93
94static int
95aead_from_name(const EVP_AEAD **aead, const EVP_CIPHER **cipher,
96    const char *name)
97{
98	*aead = NULL;
99	*cipher = NULL;
100
101	if (strcmp(name, "aes-128-gcm") == 0) {
102		*aead = EVP_aead_aes_128_gcm();
103		*cipher = EVP_aes_128_gcm();
104	} else if (strcmp(name, "aes-192-gcm") == 0) {
105		*cipher = EVP_aes_192_gcm();
106	} else if (strcmp(name, "aes-256-gcm") == 0) {
107		*aead = EVP_aead_aes_256_gcm();
108		*cipher = EVP_aes_256_gcm();
109	} else if (strcmp(name, "chacha20-poly1305") == 0) {
110		*aead = EVP_aead_chacha20_poly1305();
111		*cipher = EVP_chacha20_poly1305();
112	} else if (strcmp(name, "xchacha20-poly1305") == 0) {
113		*aead = EVP_aead_xchacha20_poly1305();
114	} else {
115		fprintf(stderr, "Unknown AEAD: %s\n", name);
116		return 0;
117	}
118
119	return 1;
120}
121
122static int
123run_aead_test(const EVP_AEAD *aead, unsigned char bufs[NUM_TYPES][BUF_MAX],
124    const unsigned int lengths[NUM_TYPES], unsigned int line_no)
125{
126	EVP_AEAD_CTX *ctx;
127	unsigned char out[BUF_MAX + EVP_AEAD_MAX_TAG_LENGTH], out2[BUF_MAX];
128	size_t out_len, out_len2;
129	int ret = 0;
130
131	if ((ctx = EVP_AEAD_CTX_new()) == NULL) {
132		fprintf(stderr, "Failed to allocate AEAD context on line %u\n",
133		    line_no);
134		goto err;
135	}
136
137	if (!EVP_AEAD_CTX_init(ctx, aead, bufs[KEY], lengths[KEY],
138	    lengths[TAG], NULL)) {
139		fprintf(stderr, "Failed to init AEAD on line %u\n", line_no);
140		goto err;
141	}
142
143	if (!EVP_AEAD_CTX_seal(ctx, out, &out_len, sizeof(out), bufs[NONCE],
144	    lengths[NONCE], bufs[IN], lengths[IN], bufs[AD], lengths[AD])) {
145		fprintf(stderr, "Failed to run AEAD on line %u\n", line_no);
146		goto err;
147	}
148
149	if (out_len != lengths[CT] + lengths[TAG]) {
150		fprintf(stderr, "Bad output length on line %u: %zu vs %u\n",
151		    line_no, out_len, (unsigned)(lengths[CT] + lengths[TAG]));
152		goto err;
153	}
154
155	if (memcmp(out, bufs[CT], lengths[CT]) != 0) {
156		fprintf(stderr, "Bad output on line %u\n", line_no);
157		goto err;
158	}
159
160	if (memcmp(out + lengths[CT], bufs[TAG], lengths[TAG]) != 0) {
161		fprintf(stderr, "Bad tag on line %u\n", line_no);
162		goto err;
163	}
164
165	if (!EVP_AEAD_CTX_open(ctx, out2, &out_len2, lengths[IN], bufs[NONCE],
166	    lengths[NONCE], out, out_len, bufs[AD], lengths[AD])) {
167		fprintf(stderr, "Failed to decrypt on line %u\n", line_no);
168		goto err;
169	}
170
171	if (out_len2 != lengths[IN]) {
172		fprintf(stderr, "Bad decrypt on line %u: %zu\n",
173		    line_no, out_len2);
174		goto err;
175	}
176
177	if (memcmp(out2, bufs[IN], out_len2) != 0) {
178		fprintf(stderr, "Plaintext mismatch on line %u\n", line_no);
179		goto err;
180	}
181
182	out[0] ^= 0x80;
183	if (EVP_AEAD_CTX_open(ctx, out2, &out_len2, lengths[IN], bufs[NONCE],
184	    lengths[NONCE], out, out_len, bufs[AD], lengths[AD])) {
185		fprintf(stderr, "Decrypted bad data on line %u\n", line_no);
186		goto err;
187	}
188
189	ret = 1;
190
191 err:
192	EVP_AEAD_CTX_free(ctx);
193
194	return ret;
195}
196
197static int
198run_cipher_aead_encrypt_test(const EVP_CIPHER *cipher,
199    unsigned char bufs[NUM_TYPES][BUF_MAX],
200    const unsigned int lengths[NUM_TYPES], unsigned int line_no)
201{
202	unsigned char out[BUF_MAX + EVP_AEAD_MAX_TAG_LENGTH];
203	EVP_CIPHER_CTX *ctx;
204	size_t out_len;
205	int len;
206	int ivlen;
207	int ret = 0;
208
209	if ((ctx = EVP_CIPHER_CTX_new()) == NULL) {
210		fprintf(stderr, "FAIL: EVP_CIPHER_CTX_new\n");
211		goto err;
212	}
213
214	if (!EVP_EncryptInit_ex(ctx, cipher, NULL, NULL, NULL)) {
215		fprintf(stderr, "FAIL: EVP_EncryptInit_ex with cipher\n");
216		goto err;
217	}
218
219	if (!EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_AEAD_SET_IVLEN, lengths[NONCE], NULL)) {
220		fprintf(stderr, "FAIL: EVP_CTRL_AEAD_SET_IVLEN\n");
221		goto err;
222	}
223
224	ivlen = EVP_CIPHER_CTX_iv_length(ctx);
225	if (ivlen != (int)lengths[NONCE]) {
226		fprintf(stderr, "FAIL: ivlen %d != nonce length %d\n", ivlen,
227		    (int)lengths[NONCE]);
228		goto err;
229	}
230
231	if (!EVP_EncryptInit_ex(ctx, NULL, NULL, bufs[KEY], NULL)) {
232		fprintf(stderr, "FAIL: EVP_EncryptInit_ex with key\n");
233		goto err;
234	}
235	if (!EVP_EncryptInit_ex(ctx, NULL, NULL, NULL, bufs[NONCE])) {
236		fprintf(stderr, "FAIL: EVP_EncryptInit_ex with nonce\n");
237		goto err;
238	}
239
240	if (!EVP_EncryptUpdate(ctx, NULL, &len, bufs[AD], lengths[AD])) {
241		fprintf(stderr, "FAIL: EVP_EncryptUpdate with AD\n");
242		goto err;
243	}
244	if ((unsigned int)len != lengths[AD]) {
245		fprintf(stderr, "FAIL: EVP_EncryptUpdate with AD length = %u, "
246		    "want %u\n", len, lengths[AD]);
247		goto err;
248	}
249	if (!EVP_EncryptUpdate(ctx, out, &len, bufs[IN], lengths[IN])) {
250		fprintf(stderr, "FAIL: EVP_EncryptUpdate with plaintext\n");
251		goto err;
252	}
253	out_len = len;
254	if (!EVP_EncryptFinal_ex(ctx, out + out_len, &len)) {
255		fprintf(stderr, "FAIL: EVP_EncryptFinal_ex\n");
256		goto err;
257	}
258	out_len += len;
259	if (!EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_AEAD_GET_TAG, lengths[TAG],
260	    out + out_len)) {
261		fprintf(stderr, "FAIL: EVP_EncryptInit_ex with cipher\n");
262		goto err;
263	}
264	out_len += lengths[TAG];
265
266	if (out_len != lengths[CT] + lengths[TAG]) {
267		fprintf(stderr, "Bad output length on line %u: %zu vs %u\n",
268		    line_no, out_len, (unsigned)(lengths[CT] + lengths[TAG]));
269		goto err;
270	}
271
272	if (memcmp(out, bufs[CT], lengths[CT]) != 0) {
273		fprintf(stderr, "Bad output on line %u\n", line_no);
274		goto err;
275	}
276
277	if (memcmp(out + lengths[CT], bufs[TAG], lengths[TAG]) != 0) {
278		fprintf(stderr, "Bad tag on line %u\n", line_no);
279		goto err;
280	}
281
282	ret = 1;
283
284 err:
285	EVP_CIPHER_CTX_free(ctx);
286
287	return ret;
288}
289
290static int
291run_cipher_aead_decrypt_test(const EVP_CIPHER *cipher, int invalid,
292    unsigned char bufs[NUM_TYPES][BUF_MAX],
293    const unsigned int lengths[NUM_TYPES], unsigned int line_no)
294{
295	unsigned char in[BUF_MAX], out[BUF_MAX + EVP_AEAD_MAX_TAG_LENGTH];
296	EVP_CIPHER_CTX *ctx;
297	size_t out_len;
298	int len;
299	int ret = 0;
300
301	if ((ctx = EVP_CIPHER_CTX_new()) == NULL) {
302		fprintf(stderr, "FAIL: EVP_CIPHER_CTX_new\n");
303		goto err;
304	}
305
306	if (!EVP_DecryptInit_ex(ctx, cipher, NULL, NULL, NULL)) {
307		fprintf(stderr, "FAIL: EVP_DecryptInit_ex with cipher\n");
308		goto err;
309	}
310
311	if (!EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_AEAD_SET_IVLEN, lengths[NONCE],
312	    NULL)) {
313		fprintf(stderr, "FAIL: EVP_CTRL_AEAD_SET_IVLEN\n");
314		goto err;
315	}
316
317	memcpy(in, bufs[TAG], lengths[TAG]);
318	if (invalid && lengths[CT] == 0)
319		in[0] ^= 0x80;
320
321	if (!EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_AEAD_SET_TAG, lengths[TAG], in)) {
322		fprintf(stderr, "FAIL: EVP_CTRL_AEAD_SET_TAG\n");
323		goto err;
324	}
325
326	if (!EVP_DecryptInit_ex(ctx, NULL, NULL, bufs[KEY], NULL)) {
327		fprintf(stderr, "FAIL: EVP_DecryptInit_ex with key\n");
328		goto err;
329	}
330	if (!EVP_DecryptInit_ex(ctx, NULL, NULL, NULL, bufs[NONCE])) {
331		fprintf(stderr, "FAIL: EVP_DecryptInit_ex with nonce\n");
332		goto err;
333	}
334
335	if (!EVP_DecryptUpdate(ctx, NULL, &len, bufs[AD], lengths[AD])) {
336		fprintf(stderr, "FAIL: EVP_DecryptUpdate with AD\n");
337		goto err;
338	}
339	if ((unsigned int)len != lengths[AD]) {
340		fprintf(stderr, "FAIL: EVP_EncryptUpdate with AD length = %u, "
341		    "want %u\n", len, lengths[AD]);
342		goto err;
343	}
344
345	memcpy(in, bufs[CT], lengths[CT]);
346	if (invalid && lengths[CT] > 0)
347		in[0] ^= 0x80;
348
349	if (!EVP_DecryptUpdate(ctx, out, &len, in, lengths[CT])) {
350		fprintf(stderr, "FAIL: EVP_DecryptUpdate with ciphertext\n");
351		goto err;
352	}
353	out_len = len;
354
355	if (invalid) {
356		if (EVP_DecryptFinal_ex(ctx, out + out_len, &len)) {
357			fprintf(stderr, "FAIL: EVP_DecryptFinal_ex succeeded "
358			    "with invalid ciphertext on line %u\n", line_no);
359			goto err;
360		}
361		goto done;
362	}
363
364	if (!EVP_DecryptFinal_ex(ctx, out + out_len, &len)) {
365		fprintf(stderr, "FAIL: EVP_DecryptFinal_ex\n");
366		goto err;
367	}
368	out_len += len;
369
370	if (out_len != lengths[IN]) {
371		fprintf(stderr, "Bad decrypt on line %u: %zu\n",
372		    line_no, out_len);
373		goto err;
374	}
375
376	if (memcmp(out, bufs[IN], out_len) != 0) {
377		fprintf(stderr, "Plaintext mismatch on line %u\n", line_no);
378		goto err;
379	}
380
381 done:
382	ret = 1;
383
384 err:
385	EVP_CIPHER_CTX_free(ctx);
386
387	return ret;
388}
389
390static int
391run_cipher_aead_test(const EVP_CIPHER *cipher,
392    unsigned char bufs[NUM_TYPES][BUF_MAX],
393    const unsigned int lengths[NUM_TYPES], unsigned int line_no)
394{
395	if (!run_cipher_aead_encrypt_test(cipher, bufs, lengths, line_no))
396		return 0;
397	if (!run_cipher_aead_decrypt_test(cipher, 0, bufs, lengths, line_no))
398		return 0;
399	if (!run_cipher_aead_decrypt_test(cipher, 1, bufs, lengths, line_no))
400		return 0;
401
402	return 1;
403}
404
405int
406main(int argc, char **argv)
407{
408	FILE *f;
409	const EVP_AEAD *aead = NULL;
410	const EVP_CIPHER *cipher = NULL;
411	unsigned int line_no = 0, num_tests = 0, j;
412	unsigned char bufs[NUM_TYPES][BUF_MAX];
413	unsigned int lengths[NUM_TYPES];
414	const char *aeadname;
415
416	if (argc != 3) {
417		fprintf(stderr, "%s <aead> <test file.txt>\n", argv[0]);
418		return 1;
419	}
420
421	if ((f = fopen(argv[2], "r")) == NULL) {
422		perror("failed to open input");
423		return 1;
424	}
425
426	for (j = 0; j < NUM_TYPES; j++)
427		lengths[j] = 0;
428
429	for (;;) {
430		char line[4096];
431		unsigned int i, type_len = 0;
432
433		unsigned char *buf = NULL;
434		unsigned int *buf_len = NULL;
435
436		if (!fgets(line, sizeof(line), f))
437			break;
438
439		line_no++;
440		if (line[0] == '#')
441			continue;
442
443		if (line[0] == '\n' || line[0] == 0) {
444			/* Run a test, if possible. */
445			char any_values_set = 0;
446			for (j = 0; j < NUM_TYPES; j++) {
447				if (lengths[j] != 0) {
448					any_values_set = 1;
449					break;
450				}
451			}
452
453			if (!any_values_set)
454				continue;
455
456			aeadname = argv[1];
457			if (lengths[AEAD] != 0)
458				aeadname = bufs[AEAD];
459
460			if (!aead_from_name(&aead, &cipher, aeadname)) {
461				fprintf(stderr, "Aborting...\n");
462				return 4;
463			}
464
465			if (aead != NULL) {
466				if (!run_aead_test(aead, bufs, lengths,
467				    line_no))
468					return 4;
469			}
470			if (cipher != NULL) {
471				if (!run_cipher_aead_test(cipher, bufs, lengths,
472				    line_no))
473					return 4;
474			}
475
476			for (j = 0; j < NUM_TYPES; j++)
477				lengths[j] = 0;
478
479			num_tests++;
480			continue;
481		}
482
483		/*
484		 * Each line looks like:
485		 *   TYPE: 0123abc
486		 * Where "TYPE" is the type of the data on the line,
487		 * e.g. "KEY".
488		 */
489		for (i = 0; line[i] != 0 && line[i] != '\n'; i++) {
490			if (line[i] == ':') {
491				type_len = i;
492				break;
493			}
494		}
495		i++;
496
497		if (type_len == 0) {
498			fprintf(stderr, "Parse error on line %u\n", line_no);
499			return 3;
500		}
501
502		/* After the colon, there's optional whitespace. */
503		for (; line[i] != 0 && line[i] != '\n'; i++) {
504			if (line[i] != ' ' && line[i] != '\t')
505				break;
506		}
507
508		line[type_len] = 0;
509		for (j = 0; j < NUM_TYPES; j++) {
510			if (strcmp(line, NAMES[j]) != 0)
511				continue;
512			if (lengths[j] != 0) {
513				fprintf(stderr, "Duplicate value on line %u\n",
514				    line_no);
515				return 3;
516			}
517			buf = bufs[j];
518			buf_len = &lengths[j];
519			break;
520		}
521
522		if (buf == NULL) {
523			fprintf(stderr, "Unknown line type on line %u\n",
524			    line_no);
525			return 3;
526		}
527
528		if (j == AEAD) {
529			*buf_len = strlcpy(buf, line + i, BUF_MAX);
530			for (j = 0; j < BUF_MAX; j++) {
531				if (buf[j] == '\n')
532					buf[j] = '\0';
533			}
534			continue;
535		}
536
537		if (line[i] == '"') {
538			i++;
539			for (j = 0; line[i] != 0 && line[i] != '\n'; i++) {
540				if (line[i] == '"')
541					break;
542				if (j == BUF_MAX) {
543					fprintf(stderr, "Too much data on "
544					    "line %u (max is %u bytes)\n",
545					    line_no, (unsigned) BUF_MAX);
546					return 3;
547				}
548				buf[j++] = line[i];
549				*buf_len = *buf_len + 1;
550			}
551			if (line[i + 1] != 0 && line[i + 1] != '\n') {
552				fprintf(stderr, "Trailing data on line %u\n",
553				    line_no);
554				return 3;
555			}
556		} else {
557			for (j = 0; line[i] != 0 && line[i] != '\n'; i++) {
558				unsigned char v, v2;
559				v = hex_digit(line[i++]);
560				if (line[i] == 0 || line[i] == '\n') {
561					fprintf(stderr, "Odd-length hex data "
562					    "on line %u\n", line_no);
563					return 3;
564				}
565				v2 = hex_digit(line[i]);
566				if (v > 15 || v2 > 15) {
567					fprintf(stderr, "Invalid hex char on "
568					    "line %u\n", line_no);
569					return 3;
570				}
571				v <<= 4;
572				v |= v2;
573
574				if (j == BUF_MAX) {
575					fprintf(stderr, "Too much hex data on "
576					    "line %u (max is %u bytes)\n",
577					    line_no, (unsigned) BUF_MAX);
578					return 3;
579				}
580				buf[j++] = v;
581				*buf_len = *buf_len + 1;
582			}
583		}
584	}
585
586	printf("Completed %u test cases\n", num_tests);
587	printf("PASS\n");
588	fclose(f);
589
590	return 0;
591}
592