1281681Srpaulo/*
2281681Srpaulo * AES SIV (RFC 5297)
3281681Srpaulo * Copyright (c) 2013 Cozybit, Inc.
4281681Srpaulo *
5281681Srpaulo * This software may be distributed under the terms of the BSD license.
6281681Srpaulo * See README for more details.
7281681Srpaulo */
8281681Srpaulo
9281681Srpaulo#include "includes.h"
10281681Srpaulo
11281681Srpaulo#include "common.h"
12281681Srpaulo#include "aes.h"
13281681Srpaulo#include "aes_wrap.h"
14281681Srpaulo#include "aes_siv.h"
15281681Srpaulo
16281681Srpaulo
17281681Srpaulostatic const u8 zero[AES_BLOCK_SIZE];
18281681Srpaulo
19281681Srpaulo
20281681Srpaulostatic void dbl(u8 *pad)
21281681Srpaulo{
22281681Srpaulo	int i, carry;
23281681Srpaulo
24281681Srpaulo	carry = pad[0] & 0x80;
25281681Srpaulo	for (i = 0; i < AES_BLOCK_SIZE - 1; i++)
26281681Srpaulo		pad[i] = (pad[i] << 1) | (pad[i + 1] >> 7);
27281681Srpaulo	pad[AES_BLOCK_SIZE - 1] <<= 1;
28281681Srpaulo	if (carry)
29281681Srpaulo		pad[AES_BLOCK_SIZE - 1] ^= 0x87;
30281681Srpaulo}
31281681Srpaulo
32281681Srpaulo
33281681Srpaulostatic void xor(u8 *a, const u8 *b)
34281681Srpaulo{
35281681Srpaulo	int i;
36281681Srpaulo
37281681Srpaulo	for (i = 0; i < AES_BLOCK_SIZE; i++)
38281681Srpaulo		*a++ ^= *b++;
39281681Srpaulo}
40281681Srpaulo
41281681Srpaulo
42281681Srpaulostatic void xorend(u8 *a, int alen, const u8 *b, int blen)
43281681Srpaulo{
44281681Srpaulo	int i;
45281681Srpaulo
46281681Srpaulo	if (alen < blen)
47281681Srpaulo		return;
48281681Srpaulo
49281681Srpaulo	for (i = 0; i < blen; i++)
50281681Srpaulo		a[alen - blen + i] ^= b[i];
51281681Srpaulo}
52281681Srpaulo
53281681Srpaulo
54281681Srpaulostatic void pad_block(u8 *pad, const u8 *addr, size_t len)
55281681Srpaulo{
56281681Srpaulo	os_memset(pad, 0, AES_BLOCK_SIZE);
57281681Srpaulo	os_memcpy(pad, addr, len);
58281681Srpaulo
59281681Srpaulo	if (len < AES_BLOCK_SIZE)
60281681Srpaulo		pad[len] = 0x80;
61281681Srpaulo}
62281681Srpaulo
63281681Srpaulo
64281681Srpaulostatic int aes_s2v(const u8 *key, size_t num_elem, const u8 *addr[],
65281681Srpaulo		   size_t *len, u8 *mac)
66281681Srpaulo{
67281681Srpaulo	u8 tmp[AES_BLOCK_SIZE], tmp2[AES_BLOCK_SIZE];
68281681Srpaulo	u8 *buf = NULL;
69281681Srpaulo	int ret;
70281681Srpaulo	size_t i;
71281681Srpaulo
72281681Srpaulo	if (!num_elem) {
73281681Srpaulo		os_memcpy(tmp, zero, sizeof(zero));
74281681Srpaulo		tmp[AES_BLOCK_SIZE - 1] = 1;
75281681Srpaulo		return omac1_aes_128(key, tmp, sizeof(tmp), mac);
76281681Srpaulo	}
77281681Srpaulo
78281681Srpaulo	ret = omac1_aes_128(key, zero, sizeof(zero), tmp);
79281681Srpaulo	if (ret)
80281681Srpaulo		return ret;
81281681Srpaulo
82281681Srpaulo	for (i = 0; i < num_elem - 1; i++) {
83281681Srpaulo		ret = omac1_aes_128(key, addr[i], len[i], tmp2);
84281681Srpaulo		if (ret)
85281681Srpaulo			return ret;
86281681Srpaulo
87281681Srpaulo		dbl(tmp);
88281681Srpaulo		xor(tmp, tmp2);
89281681Srpaulo	}
90281681Srpaulo	if (len[i] >= AES_BLOCK_SIZE) {
91281681Srpaulo		buf = os_malloc(len[i]);
92281681Srpaulo		if (!buf)
93281681Srpaulo			return -ENOMEM;
94281681Srpaulo
95281681Srpaulo		os_memcpy(buf, addr[i], len[i]);
96281681Srpaulo		xorend(buf, len[i], tmp, AES_BLOCK_SIZE);
97281681Srpaulo		ret = omac1_aes_128(key, buf, len[i], mac);
98281681Srpaulo		bin_clear_free(buf, len[i]);
99281681Srpaulo		return ret;
100281681Srpaulo	}
101281681Srpaulo
102281681Srpaulo	dbl(tmp);
103281681Srpaulo	pad_block(tmp2, addr[i], len[i]);
104281681Srpaulo	xor(tmp, tmp2);
105281681Srpaulo
106281681Srpaulo	return omac1_aes_128(key, tmp, sizeof(tmp), mac);
107281681Srpaulo}
108281681Srpaulo
109281681Srpaulo
110281681Srpauloint aes_siv_encrypt(const u8 *key, const u8 *pw,
111281681Srpaulo		    size_t pwlen, size_t num_elem,
112281681Srpaulo		    const u8 *addr[], const size_t *len, u8 *out)
113281681Srpaulo{
114281681Srpaulo	const u8 *_addr[6];
115281681Srpaulo	size_t _len[6];
116281681Srpaulo	const u8 *k1 = key, *k2 = key + 16;
117281681Srpaulo	u8 v[AES_BLOCK_SIZE];
118281681Srpaulo	size_t i;
119281681Srpaulo	u8 *iv, *crypt_pw;
120281681Srpaulo
121281681Srpaulo	if (num_elem > ARRAY_SIZE(_addr) - 1)
122281681Srpaulo		return -1;
123281681Srpaulo
124281681Srpaulo	for (i = 0; i < num_elem; i++) {
125281681Srpaulo		_addr[i] = addr[i];
126281681Srpaulo		_len[i] = len[i];
127281681Srpaulo	}
128281681Srpaulo	_addr[num_elem] = pw;
129281681Srpaulo	_len[num_elem] = pwlen;
130281681Srpaulo
131281681Srpaulo	if (aes_s2v(k1, num_elem + 1, _addr, _len, v))
132281681Srpaulo		return -1;
133281681Srpaulo
134281681Srpaulo	iv = out;
135281681Srpaulo	crypt_pw = out + AES_BLOCK_SIZE;
136281681Srpaulo
137281681Srpaulo	os_memcpy(iv, v, AES_BLOCK_SIZE);
138281681Srpaulo	os_memcpy(crypt_pw, pw, pwlen);
139281681Srpaulo
140281681Srpaulo	/* zero out 63rd and 31st bits of ctr (from right) */
141281681Srpaulo	v[8] &= 0x7f;
142281681Srpaulo	v[12] &= 0x7f;
143281681Srpaulo	return aes_128_ctr_encrypt(k2, v, crypt_pw, pwlen);
144281681Srpaulo}
145281681Srpaulo
146281681Srpaulo
147281681Srpauloint aes_siv_decrypt(const u8 *key, const u8 *iv_crypt, size_t iv_c_len,
148281681Srpaulo		    size_t num_elem, const u8 *addr[], const size_t *len,
149281681Srpaulo		    u8 *out)
150281681Srpaulo{
151281681Srpaulo	const u8 *_addr[6];
152281681Srpaulo	size_t _len[6];
153281681Srpaulo	const u8 *k1 = key, *k2 = key + 16;
154281681Srpaulo	size_t crypt_len;
155281681Srpaulo	size_t i;
156281681Srpaulo	int ret;
157281681Srpaulo	u8 iv[AES_BLOCK_SIZE];
158281681Srpaulo	u8 check[AES_BLOCK_SIZE];
159281681Srpaulo
160281681Srpaulo	if (iv_c_len < AES_BLOCK_SIZE || num_elem > ARRAY_SIZE(_addr) - 1)
161281681Srpaulo		return -1;
162281681Srpaulo	crypt_len = iv_c_len - AES_BLOCK_SIZE;
163281681Srpaulo
164281681Srpaulo	for (i = 0; i < num_elem; i++) {
165281681Srpaulo		_addr[i] = addr[i];
166281681Srpaulo		_len[i] = len[i];
167281681Srpaulo	}
168281681Srpaulo	_addr[num_elem] = out;
169281681Srpaulo	_len[num_elem] = crypt_len;
170281681Srpaulo
171281681Srpaulo	os_memcpy(iv, iv_crypt, AES_BLOCK_SIZE);
172281681Srpaulo	os_memcpy(out, iv_crypt + AES_BLOCK_SIZE, crypt_len);
173281681Srpaulo
174281681Srpaulo	iv[8] &= 0x7f;
175281681Srpaulo	iv[12] &= 0x7f;
176281681Srpaulo
177281681Srpaulo	ret = aes_128_ctr_encrypt(k2, iv, out, crypt_len);
178281681Srpaulo	if (ret)
179281681Srpaulo		return ret;
180281681Srpaulo
181281681Srpaulo	ret = aes_s2v(k1, num_elem + 1, _addr, _len, check);
182281681Srpaulo	if (ret)
183281681Srpaulo		return ret;
184281681Srpaulo	if (os_memcmp(check, iv_crypt, AES_BLOCK_SIZE) == 0)
185281681Srpaulo		return 0;
186281681Srpaulo
187281681Srpaulo	return -1;
188281681Srpaulo}
189