1214501Srpaulo/*
2214501Srpaulo * Crypto wrapper for internal crypto implementation - Cipher wrappers
3214501Srpaulo * Copyright (c) 2006-2009, Jouni Malinen <j@w1.fi>
4214501Srpaulo *
5214501Srpaulo * This program is free software; you can redistribute it and/or modify
6214501Srpaulo * it under the terms of the GNU General Public License version 2 as
7214501Srpaulo * published by the Free Software Foundation.
8214501Srpaulo *
9214501Srpaulo * Alternatively, this software may be distributed under the terms of BSD
10214501Srpaulo * license.
11214501Srpaulo *
12214501Srpaulo * See README and COPYING for more details.
13214501Srpaulo */
14214501Srpaulo
15214501Srpaulo#include "includes.h"
16214501Srpaulo
17214501Srpaulo#include "common.h"
18214501Srpaulo#include "crypto.h"
19214501Srpaulo#include "aes.h"
20214501Srpaulo#include "des_i.h"
21214501Srpaulo
22214501Srpaulo
/*
 * Cipher context returned by crypto_cipher_init(). @alg selects which member
 * of the union @u holds the state for this context.
 */
struct crypto_cipher {
	enum crypto_cipher_alg alg;
	union {
		struct {
			/* Keystream bytes consumed so far; passed to rc4_skip()
			 * so each call continues where the previous one ended. */
			size_t used_bytes;
			u8 key[16]; /* raw RC4 key (init rejects longer keys) */
			size_t keylen; /* number of valid bytes in key */
		} rc4;
		struct {
			u8 cbc[32]; /* CBC chaining value, seeded from the IV */
			size_t block_size; /* bytes per CBC block used by
					    * encrypt/decrypt */
			void *ctx_enc; /* from aes_encrypt_init() */
			void *ctx_dec; /* from aes_decrypt_init() */
		} aes;
		struct {
			struct des3_key_s key; /* filled by des3_key_setup() */
			u8 cbc[8]; /* CBC chaining value, seeded from the IV */
		} des3;
		struct {
			u32 ek[32]; /* used by des_block_encrypt() */
			u32 dk[32]; /* used by des_block_decrypt() */
			u8 cbc[8]; /* CBC chaining value, seeded from the IV */
		} des;
	} u;
};
48214501Srpaulo
49214501Srpaulo
50214501Srpaulostruct crypto_cipher * crypto_cipher_init(enum crypto_cipher_alg alg,
51214501Srpaulo					  const u8 *iv, const u8 *key,
52214501Srpaulo					  size_t key_len)
53214501Srpaulo{
54214501Srpaulo	struct crypto_cipher *ctx;
55214501Srpaulo
56214501Srpaulo	ctx = os_zalloc(sizeof(*ctx));
57214501Srpaulo	if (ctx == NULL)
58214501Srpaulo		return NULL;
59214501Srpaulo
60214501Srpaulo	ctx->alg = alg;
61214501Srpaulo
62214501Srpaulo	switch (alg) {
63214501Srpaulo	case CRYPTO_CIPHER_ALG_RC4:
64214501Srpaulo		if (key_len > sizeof(ctx->u.rc4.key)) {
65214501Srpaulo			os_free(ctx);
66214501Srpaulo			return NULL;
67214501Srpaulo		}
68214501Srpaulo		ctx->u.rc4.keylen = key_len;
69214501Srpaulo		os_memcpy(ctx->u.rc4.key, key, key_len);
70214501Srpaulo		break;
71214501Srpaulo	case CRYPTO_CIPHER_ALG_AES:
72214501Srpaulo		if (key_len > sizeof(ctx->u.aes.cbc)) {
73214501Srpaulo			os_free(ctx);
74214501Srpaulo			return NULL;
75214501Srpaulo		}
76214501Srpaulo		ctx->u.aes.ctx_enc = aes_encrypt_init(key, key_len);
77214501Srpaulo		if (ctx->u.aes.ctx_enc == NULL) {
78214501Srpaulo			os_free(ctx);
79214501Srpaulo			return NULL;
80214501Srpaulo		}
81214501Srpaulo		ctx->u.aes.ctx_dec = aes_decrypt_init(key, key_len);
82214501Srpaulo		if (ctx->u.aes.ctx_dec == NULL) {
83214501Srpaulo			aes_encrypt_deinit(ctx->u.aes.ctx_enc);
84214501Srpaulo			os_free(ctx);
85214501Srpaulo			return NULL;
86214501Srpaulo		}
87214501Srpaulo		ctx->u.aes.block_size = key_len;
88214501Srpaulo		os_memcpy(ctx->u.aes.cbc, iv, ctx->u.aes.block_size);
89214501Srpaulo		break;
90214501Srpaulo	case CRYPTO_CIPHER_ALG_3DES:
91214501Srpaulo		if (key_len != 24) {
92214501Srpaulo			os_free(ctx);
93214501Srpaulo			return NULL;
94214501Srpaulo		}
95214501Srpaulo		des3_key_setup(key, &ctx->u.des3.key);
96214501Srpaulo		os_memcpy(ctx->u.des3.cbc, iv, 8);
97214501Srpaulo		break;
98214501Srpaulo	case CRYPTO_CIPHER_ALG_DES:
99214501Srpaulo		if (key_len != 8) {
100214501Srpaulo			os_free(ctx);
101214501Srpaulo			return NULL;
102214501Srpaulo		}
103214501Srpaulo		des_key_setup(key, ctx->u.des.ek, ctx->u.des.dk);
104214501Srpaulo		os_memcpy(ctx->u.des.cbc, iv, 8);
105214501Srpaulo		break;
106214501Srpaulo	default:
107214501Srpaulo		os_free(ctx);
108214501Srpaulo		return NULL;
109214501Srpaulo	}
110214501Srpaulo
111214501Srpaulo	return ctx;
112214501Srpaulo}
113214501Srpaulo
114214501Srpaulo
115214501Srpauloint crypto_cipher_encrypt(struct crypto_cipher *ctx, const u8 *plain,
116214501Srpaulo			  u8 *crypt, size_t len)
117214501Srpaulo{
118214501Srpaulo	size_t i, j, blocks;
119214501Srpaulo
120214501Srpaulo	switch (ctx->alg) {
121214501Srpaulo	case CRYPTO_CIPHER_ALG_RC4:
122214501Srpaulo		if (plain != crypt)
123214501Srpaulo			os_memcpy(crypt, plain, len);
124214501Srpaulo		rc4_skip(ctx->u.rc4.key, ctx->u.rc4.keylen,
125214501Srpaulo			 ctx->u.rc4.used_bytes, crypt, len);
126214501Srpaulo		ctx->u.rc4.used_bytes += len;
127214501Srpaulo		break;
128214501Srpaulo	case CRYPTO_CIPHER_ALG_AES:
129214501Srpaulo		if (len % ctx->u.aes.block_size)
130214501Srpaulo			return -1;
131214501Srpaulo		blocks = len / ctx->u.aes.block_size;
132214501Srpaulo		for (i = 0; i < blocks; i++) {
133214501Srpaulo			for (j = 0; j < ctx->u.aes.block_size; j++)
134214501Srpaulo				ctx->u.aes.cbc[j] ^= plain[j];
135214501Srpaulo			aes_encrypt(ctx->u.aes.ctx_enc, ctx->u.aes.cbc,
136214501Srpaulo				    ctx->u.aes.cbc);
137214501Srpaulo			os_memcpy(crypt, ctx->u.aes.cbc,
138214501Srpaulo				  ctx->u.aes.block_size);
139214501Srpaulo			plain += ctx->u.aes.block_size;
140214501Srpaulo			crypt += ctx->u.aes.block_size;
141214501Srpaulo		}
142214501Srpaulo		break;
143214501Srpaulo	case CRYPTO_CIPHER_ALG_3DES:
144214501Srpaulo		if (len % 8)
145214501Srpaulo			return -1;
146214501Srpaulo		blocks = len / 8;
147214501Srpaulo		for (i = 0; i < blocks; i++) {
148214501Srpaulo			for (j = 0; j < 8; j++)
149214501Srpaulo				ctx->u.des3.cbc[j] ^= plain[j];
150214501Srpaulo			des3_encrypt(ctx->u.des3.cbc, &ctx->u.des3.key,
151214501Srpaulo				     ctx->u.des3.cbc);
152214501Srpaulo			os_memcpy(crypt, ctx->u.des3.cbc, 8);
153214501Srpaulo			plain += 8;
154214501Srpaulo			crypt += 8;
155214501Srpaulo		}
156214501Srpaulo		break;
157214501Srpaulo	case CRYPTO_CIPHER_ALG_DES:
158214501Srpaulo		if (len % 8)
159214501Srpaulo			return -1;
160214501Srpaulo		blocks = len / 8;
161214501Srpaulo		for (i = 0; i < blocks; i++) {
162214501Srpaulo			for (j = 0; j < 8; j++)
163214501Srpaulo				ctx->u.des3.cbc[j] ^= plain[j];
164214501Srpaulo			des_block_encrypt(ctx->u.des.cbc, ctx->u.des.ek,
165214501Srpaulo					  ctx->u.des.cbc);
166214501Srpaulo			os_memcpy(crypt, ctx->u.des.cbc, 8);
167214501Srpaulo			plain += 8;
168214501Srpaulo			crypt += 8;
169214501Srpaulo		}
170214501Srpaulo		break;
171214501Srpaulo	default:
172214501Srpaulo		return -1;
173214501Srpaulo	}
174214501Srpaulo
175214501Srpaulo	return 0;
176214501Srpaulo}
177214501Srpaulo
178214501Srpaulo
179214501Srpauloint crypto_cipher_decrypt(struct crypto_cipher *ctx, const u8 *crypt,
180214501Srpaulo			  u8 *plain, size_t len)
181214501Srpaulo{
182214501Srpaulo	size_t i, j, blocks;
183214501Srpaulo	u8 tmp[32];
184214501Srpaulo
185214501Srpaulo	switch (ctx->alg) {
186214501Srpaulo	case CRYPTO_CIPHER_ALG_RC4:
187214501Srpaulo		if (plain != crypt)
188214501Srpaulo			os_memcpy(plain, crypt, len);
189214501Srpaulo		rc4_skip(ctx->u.rc4.key, ctx->u.rc4.keylen,
190214501Srpaulo			 ctx->u.rc4.used_bytes, plain, len);
191214501Srpaulo		ctx->u.rc4.used_bytes += len;
192214501Srpaulo		break;
193214501Srpaulo	case CRYPTO_CIPHER_ALG_AES:
194214501Srpaulo		if (len % ctx->u.aes.block_size)
195214501Srpaulo			return -1;
196214501Srpaulo		blocks = len / ctx->u.aes.block_size;
197214501Srpaulo		for (i = 0; i < blocks; i++) {
198214501Srpaulo			os_memcpy(tmp, crypt, ctx->u.aes.block_size);
199214501Srpaulo			aes_decrypt(ctx->u.aes.ctx_dec, crypt, plain);
200214501Srpaulo			for (j = 0; j < ctx->u.aes.block_size; j++)
201214501Srpaulo				plain[j] ^= ctx->u.aes.cbc[j];
202214501Srpaulo			os_memcpy(ctx->u.aes.cbc, tmp, ctx->u.aes.block_size);
203214501Srpaulo			plain += ctx->u.aes.block_size;
204214501Srpaulo			crypt += ctx->u.aes.block_size;
205214501Srpaulo		}
206214501Srpaulo		break;
207214501Srpaulo	case CRYPTO_CIPHER_ALG_3DES:
208214501Srpaulo		if (len % 8)
209214501Srpaulo			return -1;
210214501Srpaulo		blocks = len / 8;
211214501Srpaulo		for (i = 0; i < blocks; i++) {
212214501Srpaulo			os_memcpy(tmp, crypt, 8);
213214501Srpaulo			des3_decrypt(crypt, &ctx->u.des3.key, plain);
214214501Srpaulo			for (j = 0; j < 8; j++)
215214501Srpaulo				plain[j] ^= ctx->u.des3.cbc[j];
216214501Srpaulo			os_memcpy(ctx->u.des3.cbc, tmp, 8);
217214501Srpaulo			plain += 8;
218214501Srpaulo			crypt += 8;
219214501Srpaulo		}
220214501Srpaulo		break;
221214501Srpaulo	case CRYPTO_CIPHER_ALG_DES:
222214501Srpaulo		if (len % 8)
223214501Srpaulo			return -1;
224214501Srpaulo		blocks = len / 8;
225214501Srpaulo		for (i = 0; i < blocks; i++) {
226214501Srpaulo			os_memcpy(tmp, crypt, 8);
227214501Srpaulo			des_block_decrypt(crypt, ctx->u.des.dk, plain);
228214501Srpaulo			for (j = 0; j < 8; j++)
229214501Srpaulo				plain[j] ^= ctx->u.des.cbc[j];
230214501Srpaulo			os_memcpy(ctx->u.des.cbc, tmp, 8);
231214501Srpaulo			plain += 8;
232214501Srpaulo			crypt += 8;
233214501Srpaulo		}
234214501Srpaulo		break;
235214501Srpaulo	default:
236214501Srpaulo		return -1;
237214501Srpaulo	}
238214501Srpaulo
239214501Srpaulo	return 0;
240214501Srpaulo}
241214501Srpaulo
242214501Srpaulo
243214501Srpaulovoid crypto_cipher_deinit(struct crypto_cipher *ctx)
244214501Srpaulo{
245214501Srpaulo	switch (ctx->alg) {
246214501Srpaulo	case CRYPTO_CIPHER_ALG_AES:
247214501Srpaulo		aes_encrypt_deinit(ctx->u.aes.ctx_enc);
248214501Srpaulo		aes_decrypt_deinit(ctx->u.aes.ctx_dec);
249214501Srpaulo		break;
250214501Srpaulo	case CRYPTO_CIPHER_ALG_3DES:
251214501Srpaulo		break;
252214501Srpaulo	default:
253214501Srpaulo		break;
254214501Srpaulo	}
255214501Srpaulo	os_free(ctx);
256214501Srpaulo}
257