1// SPDX-License-Identifier: GPL-2.0+
2/*
3 * Copyright (c) 2019 Philippe Reynes <philippe.reynes@softathome.com>
4 *
5 * Unit tests for aes functions
6 */
7
8#include <command.h>
9#include <hexdump.h>
10#include <rand.h>
11#include <uboot_aes.h>
12#include <test/lib.h>
13#include <test/test.h>
14#include <test/ut.h>
15
/*
 * Kind of check performed for one test_aes[] entry (stored in .type).
 * An enum rather than #defines keeps the related values grouped and
 * visible to the debugger; the numeric values are unchanged.
 */
enum {
	TEST_AES_ONE_BLOCK = 0,	/* lib_test_aes_one_block(): one block */
	TEST_AES_CBC_CHAIN = 1,	/* lib_test_aes_cbc_chain(): CBC chain */
};

/* One AES test configuration; the entries live in test_aes[] below */
struct test_aes_s {
	int key_len;		/* raw key size, in bytes */
	int key_exp_len;	/* expanded-key buffer size, in bytes */
	int type;		/* TEST_AES_ONE_BLOCK or TEST_AES_CBC_CHAIN */
	int num_block;		/* plaintext length in AES_BLOCK_LENGTH units */
};
25
26static struct test_aes_s test_aes[] = {
27	{ AES128_KEY_LENGTH, AES128_EXPAND_KEY_LENGTH, TEST_AES_ONE_BLOCK,  1 },
28	{ AES128_KEY_LENGTH, AES128_EXPAND_KEY_LENGTH, TEST_AES_CBC_CHAIN, 16 },
29	{ AES192_KEY_LENGTH, AES192_EXPAND_KEY_LENGTH, TEST_AES_ONE_BLOCK,  1 },
30	{ AES192_KEY_LENGTH, AES192_EXPAND_KEY_LENGTH, TEST_AES_CBC_CHAIN, 16 },
31	{ AES256_KEY_LENGTH, AES256_EXPAND_KEY_LENGTH, TEST_AES_ONE_BLOCK,  1 },
32	{ AES256_KEY_LENGTH, AES256_EXPAND_KEY_LENGTH, TEST_AES_CBC_CHAIN, 16 },
33};
34
35static void rand_buf(u8 *buf, int size)
36{
37	int i;
38
39	for (i = 0; i < size; i++)
40		buf[i] = rand() & 0xff;
41}
42
43static int lib_test_aes_one_block(struct unit_test_state *uts, int key_len,
44				  u8 *key_exp, u8 *iv, int num_block,
45				  u8 *nocipher, u8 *ciphered, u8 *uncipher)
46{
47	aes_encrypt(key_len, nocipher, key_exp, ciphered);
48	aes_decrypt(key_len, ciphered, key_exp, uncipher);
49
50	ut_asserteq_mem(nocipher, uncipher, AES_BLOCK_LENGTH);
51
52	/* corrupt the expanded key */
53	key_exp[0]++;
54	aes_decrypt(key_len, ciphered, key_exp, uncipher);
55	ut_assertf(memcmp(nocipher, uncipher, AES_BLOCK_LENGTH),
56		   "nocipher and uncipher should be different\n");
57
58	return 0;
59}
60
61static int lib_test_aes_cbc_chain(struct unit_test_state *uts, int key_len,
62				  u8 *key_exp, u8 *iv, int num_block,
63				  u8 *nocipher, u8 *ciphered, u8 *uncipher)
64{
65	aes_cbc_encrypt_blocks(key_len, key_exp, iv,
66			       nocipher, ciphered, num_block);
67	aes_cbc_decrypt_blocks(key_len, key_exp, iv,
68			       ciphered, uncipher, num_block);
69
70	ut_asserteq_mem(nocipher, uncipher, num_block * AES_BLOCK_LENGTH);
71
72	/* corrupt the expanded key */
73	key_exp[0]++;
74	aes_cbc_decrypt_blocks(key_len, key_exp, iv,
75			       ciphered, uncipher, num_block);
76	ut_assertf(memcmp(nocipher, uncipher, num_block * AES_BLOCK_LENGTH),
77		   "nocipher and uncipher should be different\n");
78
79	return 0;
80}
81
82static int _lib_test_aes_run(struct unit_test_state *uts, int key_len,
83			     int key_exp_len, int type, int num_block)
84{
85	u8 *key, *key_exp, *iv;
86	u8 *nocipher, *ciphered, *uncipher;
87	int ret;
88
89	/* Allocate all the buffer */
90	key = malloc(key_len);
91	key_exp = malloc(key_exp_len);
92	iv = malloc(AES_BLOCK_LENGTH);
93	nocipher = malloc(num_block * AES_BLOCK_LENGTH);
94	ciphered = malloc((num_block + 1) * AES_BLOCK_LENGTH);
95	uncipher = malloc((num_block + 1) * AES_BLOCK_LENGTH);
96
97	if (!key || !key_exp || !iv || !nocipher || !ciphered || !uncipher) {
98		printf("%s: can't allocate memory\n", __func__);
99		ret = -1;
100		goto out;
101	}
102
103	/* Initialize all buffer */
104	rand_buf(key, key_len);
105	rand_buf(iv, AES_BLOCK_LENGTH);
106	rand_buf(nocipher, num_block * AES_BLOCK_LENGTH);
107	memset(ciphered, 0, (num_block + 1) * AES_BLOCK_LENGTH);
108	memset(uncipher, 0, (num_block + 1) * AES_BLOCK_LENGTH);
109
110	/* Expand the key */
111	aes_expand_key(key, key_len, key_exp);
112
113	/* Encrypt and decrypt */
114	switch (type) {
115	case TEST_AES_ONE_BLOCK:
116		ret = lib_test_aes_one_block(uts, key_len, key_exp, iv,
117					     num_block, nocipher,
118					     ciphered, uncipher);
119		break;
120	case TEST_AES_CBC_CHAIN:
121		ret = lib_test_aes_cbc_chain(uts, key_len, key_exp, iv,
122					     num_block, nocipher,
123					     ciphered, uncipher);
124		break;
125	default:
126		printf("%s: unknown type (type=%d)\n", __func__, type);
127		ret = -1;
128	};
129
130 out:
131	/* Free all the data */
132	free(key);
133	free(key_exp);
134	free(iv);
135	free(nocipher);
136	free(ciphered);
137	free(uncipher);
138
139	return ret;
140}
141
142static int lib_test_aes_run(struct unit_test_state *uts,
143			    struct test_aes_s *test)
144{
145	int key_len = test->key_len;
146	int key_exp_len = test->key_exp_len;
147	int type = test->type;
148	int num_block = test->num_block;
149
150	return _lib_test_aes_run(uts, key_len, key_exp_len,
151				 type, num_block);
152}
153
154static int lib_test_aes(struct unit_test_state *uts)
155{
156	int i, ret = 0;
157
158	for (i = 0; i < ARRAY_SIZE(test_aes); i++) {
159		ret = lib_test_aes_run(uts, &test_aes[i]);
160		if (ret)
161			break;
162	}
163
164	return ret;
165}
166
167LIB_TEST(lib_test_aes, 0);
168