sm2_crypt.c revision 1.1
1/*	$OpenBSD: sm2_crypt.c,v 1.1 2021/08/18 16:04:32 tb Exp $ */
2/*
3 * Copyright (c) 2017, 2019 Ribose Inc
4 *
5 * Permission to use, copy, modify, and/or distribute this software for any
6 * purpose with or without fee is hereby granted, provided that the above
7 * copyright notice and this permission notice appear in all copies.
8 *
9 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
10 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
11 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
12 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
13 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
14 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
15 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
16 */
17
18#ifndef OPENSSL_NO_SM2
19
#include <limits.h>
#include <string.h>

#include <openssl/asn1.h>
#include <openssl/asn1t.h>
#include <openssl/bn.h>
#include <openssl/err.h>
#include <openssl/evp.h>
#include <openssl/sm2.h>

#include "sm2_locl.h"
30
31typedef struct SM2_Ciphertext_st SM2_Ciphertext;
32
33SM2_Ciphertext *SM2_Ciphertext_new(void);
34void SM2_Ciphertext_free(SM2_Ciphertext *a);
35SM2_Ciphertext *d2i_SM2_Ciphertext(SM2_Ciphertext **a, const unsigned char **in,
36    long len);
37int i2d_SM2_Ciphertext(SM2_Ciphertext *a, unsigned char **out);
38
39struct SM2_Ciphertext_st {
40	BIGNUM *C1x;
41	BIGNUM *C1y;
42	ASN1_OCTET_STRING *C3;
43	ASN1_OCTET_STRING *C2;
44};
45
46static const ASN1_TEMPLATE SM2_Ciphertext_seq_tt[] = {
47	{
48		.flags = 0,
49		.tag = 0,
50		.offset = offsetof(SM2_Ciphertext, C1x),
51		.field_name = "C1x",
52		.item = &BIGNUM_it,
53	},
54	{
55		.flags = 0,
56		.tag = 0,
57		.offset = offsetof(SM2_Ciphertext, C1y),
58		.field_name = "C1y",
59		.item = &BIGNUM_it,
60	},
61	{
62		.flags = 0,
63		.tag = 0,
64		.offset = offsetof(SM2_Ciphertext, C3),
65		.field_name = "C3",
66		.item = &ASN1_OCTET_STRING_it,
67	},
68	{
69		.flags = 0,
70		.tag = 0,
71		.offset = offsetof(SM2_Ciphertext, C2),
72		.field_name = "C2",
73		.item = &ASN1_OCTET_STRING_it,
74	},
75};
76
77const ASN1_ITEM SM2_Ciphertext_it = {
78	.itype = ASN1_ITYPE_SEQUENCE,
79	.utype = V_ASN1_SEQUENCE,
80	.templates = SM2_Ciphertext_seq_tt,
81	.tcount = sizeof(SM2_Ciphertext_seq_tt) / sizeof(ASN1_TEMPLATE),
82	.funcs = NULL,
83	.size = sizeof(SM2_Ciphertext),
84	.sname = "SM2_Ciphertext",
85};
86
87SM2_Ciphertext *
88d2i_SM2_Ciphertext(SM2_Ciphertext **a, const unsigned char **in, long len)
89{
90	return (SM2_Ciphertext *) ASN1_item_d2i((ASN1_VALUE **)a, in, len,
91	    &SM2_Ciphertext_it);
92}
93
94int
95i2d_SM2_Ciphertext(SM2_Ciphertext *a, unsigned char **out)
96{
97	return ASN1_item_i2d((ASN1_VALUE *)a, out, &SM2_Ciphertext_it);
98}
99
100SM2_Ciphertext *
101SM2_Ciphertext_new(void)
102{
103	return (SM2_Ciphertext *)ASN1_item_new(&SM2_Ciphertext_it);
104}
105
106void
107SM2_Ciphertext_free(SM2_Ciphertext *a)
108{
109	ASN1_item_free((ASN1_VALUE *)a, &SM2_Ciphertext_it);
110}
111
112static size_t
113ec_field_size(const EC_GROUP *group)
114{
115	/* Is there some simpler way to do this? */
116	BIGNUM *p;
117	size_t field_size = 0;
118
119	if ((p = BN_new()) == NULL)
120		goto err;
121	if (!EC_GROUP_get_curve(group, p, NULL, NULL, NULL))
122		goto err;
123	field_size = BN_num_bytes(p);
124 err:
125	BN_free(p);
126	return field_size;
127}
128
129int
130SM2_plaintext_size(const EC_KEY *key, const EVP_MD *digest, size_t msg_len,
131    size_t *pl_size)
132{
133	size_t field_size, overhead;
134	int md_size;
135
136	if ((field_size = ec_field_size(EC_KEY_get0_group(key))) == 0) {
137		SM2error(SM2_R_INVALID_FIELD);
138		return 0;
139	}
140
141	if ((md_size = EVP_MD_size(digest)) < 0) {
142		SM2error(SM2_R_INVALID_DIGEST);
143		return 0;
144	}
145
146	overhead = 10 + 2 * field_size + md_size;
147	if (msg_len <= overhead) {
148		SM2error(SM2_R_INVALID_ARGUMENT);
149		return 0;
150	}
151
152	*pl_size = msg_len - overhead;
153	return 1;
154}
155
156int
157SM2_ciphertext_size(const EC_KEY *key, const EVP_MD *digest, size_t msg_len,
158    size_t *c_size)
159{
160	size_t asn_size, field_size;
161	int md_size;
162
163	if ((field_size = ec_field_size(EC_KEY_get0_group(key))) == 0) {
164		SM2error(SM2_R_INVALID_FIELD);
165		return 0;
166	}
167
168	if ((md_size = EVP_MD_size(digest)) < 0) {
169		SM2error(SM2_R_INVALID_DIGEST);
170		return 0;
171	}
172
173	asn_size = 2 * ASN1_object_size(0, field_size + 1, V_ASN1_INTEGER) +
174	    ASN1_object_size(0, md_size, V_ASN1_OCTET_STRING) +
175	    ASN1_object_size(0, msg_len, V_ASN1_OCTET_STRING);
176
177	*c_size = ASN1_object_size(1, asn_size, V_ASN1_SEQUENCE);
178	return 1;
179}
180
181int
182sm2_kdf(uint8_t *key, size_t key_len, uint8_t *secret, size_t secret_len,
183    const EVP_MD *digest)
184{
185	EVP_MD_CTX *hash;
186	uint8_t *hash_buf = NULL;
187	uint32_t ctr = 1;
188	uint8_t ctr_buf[4] = {0};
189	size_t hadd, hlen;
190	int rc = 0;
191
192	if ((hash = EVP_MD_CTX_new()) == NULL) {
193		SM2error(ERR_R_MALLOC_FAILURE);
194		goto err;
195	}
196
197	if ((hlen = EVP_MD_size(digest)) < 0) {
198		SM2error(SM2_R_INVALID_DIGEST);
199		goto err;
200	}
201	if ((hash_buf = malloc(hlen)) == NULL) {
202		SM2error(ERR_R_MALLOC_FAILURE);
203		goto err;
204	}
205
206	EVP_MD_CTX_init(hash);
207	while ((key_len > 0) && (ctr != 0)) {
208		if (!EVP_DigestInit_ex(hash, digest, NULL)) {
209			SM2error(ERR_R_EVP_LIB);
210			goto err;
211		}
212		if (!EVP_DigestUpdate(hash, secret, secret_len)) {
213			SM2error(ERR_R_EVP_LIB);
214			goto err;
215		}
216
217		/* big-endian counter representation */
218		ctr_buf[0] = (ctr >> 24) & 0xff;
219		ctr_buf[1] = (ctr >> 16) & 0xff;
220		ctr_buf[2] = (ctr >> 8) & 0xff;
221		ctr_buf[3] = ctr & 0xff;
222		ctr++;
223
224		if (!EVP_DigestUpdate(hash, ctr_buf, 4)) {
225			SM2error(ERR_R_EVP_LIB);
226			goto err;
227		}
228		if (!EVP_DigestFinal(hash, hash_buf, NULL)) {
229			SM2error(ERR_R_EVP_LIB);
230			goto err;
231		}
232
233		hadd = key_len > hlen ? hlen : key_len;
234		memcpy(key, hash_buf, hadd);
235		memset(hash_buf, 0, hlen);
236		key_len -= hadd;
237		key += hadd;
238	}
239
240	rc = 1;
241 err:
242	free(hash_buf);
243	EVP_MD_CTX_free(hash);
244	return rc;
245}
246
247int
248SM2_encrypt(const EC_KEY *key, const EVP_MD *digest, const uint8_t *msg,
249    size_t msg_len, uint8_t *ciphertext_buf, size_t *ciphertext_len)
250{
251	SM2_Ciphertext ctext_struct;
252	EVP_MD_CTX *hash = NULL;
253	BN_CTX *ctx = NULL;
254	BIGNUM *order = NULL;
255	BIGNUM *k, *x1, *y1, *x2, *y2;
256	const EC_GROUP *group;
257	const EC_POINT *P;
258	EC_POINT *kG = NULL, *kP = NULL;
259	uint8_t *msg_mask = NULL, *x2y2 = NULL, *C3 = NULL;
260	size_t C3_size, field_size, i, x2size, y2size;
261	int rc = 0;
262	int clen;
263
264	ctext_struct.C2 = NULL;
265	ctext_struct.C3 = NULL;
266
267	if ((hash = EVP_MD_CTX_new()) == NULL) {
268		SM2error(ERR_R_MALLOC_FAILURE);
269		goto err;
270	}
271
272	if ((group = EC_KEY_get0_group(key)) == NULL) {
273		SM2error(SM2_R_INVALID_KEY);
274		goto err;
275	}
276
277	if ((order = BN_new()) == NULL) {
278		SM2error(ERR_R_MALLOC_FAILURE);
279		goto err;
280	}
281
282	if (!EC_GROUP_get_order(group, order, NULL)) {
283		SM2error(SM2_R_INVALID_GROUP_ORDER);
284		goto err;
285	}
286
287	if ((P = EC_KEY_get0_public_key(key)) == NULL) {
288		SM2error(SM2_R_INVALID_KEY);
289		goto err;
290	}
291
292	if ((field_size = ec_field_size(group)) == 0) {
293		SM2error(SM2_R_INVALID_FIELD);
294		goto err;
295	}
296
297	if ((C3_size = EVP_MD_size(digest)) < 0) {
298		SM2error(SM2_R_INVALID_DIGEST);
299		goto err;
300	}
301
302	if ((kG = EC_POINT_new(group)) == NULL) {
303		SM2error(ERR_R_MALLOC_FAILURE);
304		goto err;
305	}
306	if ((kP = EC_POINT_new(group)) == NULL) {
307		SM2error(ERR_R_MALLOC_FAILURE);
308		goto err;
309	}
310
311	if ((ctx = BN_CTX_new()) == NULL) {
312		SM2error(ERR_R_MALLOC_FAILURE);
313		goto err;
314	}
315
316	BN_CTX_start(ctx);
317	if ((k = BN_CTX_get(ctx)) == NULL) {
318		SM2error(ERR_R_BN_LIB);
319		goto err;
320	}
321	if ((x1 = BN_CTX_get(ctx)) == NULL) {
322		SM2error(ERR_R_BN_LIB);
323		goto err;
324	}
325	if ((x2 = BN_CTX_get(ctx)) == NULL) {
326		SM2error(ERR_R_BN_LIB);
327		goto err;
328	}
329	if ((y1 = BN_CTX_get(ctx)) == NULL) {
330		SM2error(ERR_R_BN_LIB);
331		goto err;
332	}
333	if ((y2 = BN_CTX_get(ctx)) == NULL) {
334		SM2error(ERR_R_BN_LIB);
335		goto err;
336	}
337
338	if ((x2y2 = calloc(2, field_size)) == NULL) {
339		SM2error(ERR_R_MALLOC_FAILURE);
340		goto err;
341	}
342
343	if ((C3 = calloc(1, C3_size)) == NULL) {
344		SM2error(ERR_R_MALLOC_FAILURE);
345		goto err;
346	}
347
348	memset(ciphertext_buf, 0, *ciphertext_len);
349
350	if (!BN_rand_range(k, order)) {
351		SM2error(SM2_R_RANDOM_NUMBER_GENERATION_FAILED);
352		goto err;
353	}
354
355	if (!EC_POINT_mul(group, kG, k, NULL, NULL, ctx)) {
356		SM2error(ERR_R_EC_LIB);
357		goto err;
358	}
359
360	if (!EC_POINT_get_affine_coordinates(group, kG, x1, y1, ctx)) {
361		SM2error(ERR_R_EC_LIB);
362		goto err;
363	}
364
365	if (!EC_POINT_mul(group, kP, NULL, P, k, ctx)) {
366		SM2error(ERR_R_EC_LIB);
367		goto err;
368	}
369
370	if (!EC_POINT_get_affine_coordinates(group, kP, x2, y2, ctx)) {
371		SM2error(ERR_R_EC_LIB);
372		goto err;
373	}
374
375	if ((x2size = BN_num_bytes(x2)) > field_size ||
376	    (y2size = BN_num_bytes(y2)) > field_size) {
377		SM2error(SM2_R_BIGNUM_OUT_OF_RANGE);
378		goto err;
379	}
380
381	BN_bn2bin(x2, x2y2 + field_size - x2size);
382	BN_bn2bin(y2, x2y2 + 2 * field_size - y2size);
383
384	if ((msg_mask = calloc(1, msg_len)) == NULL) {
385		SM2error(ERR_R_MALLOC_FAILURE);
386		goto err;
387	}
388
389	if (!sm2_kdf(msg_mask, msg_len, x2y2, 2 * field_size, digest)) {
390		SM2error(SM2_R_KDF_FAILURE);
391		goto err;
392	}
393
394	for (i = 0; i != msg_len; i++)
395		msg_mask[i] ^= msg[i];
396
397	if (!EVP_DigestInit(hash, digest)) {
398		SM2error(ERR_R_EVP_LIB);
399		goto err;
400	}
401
402	if (!EVP_DigestUpdate(hash, x2y2, field_size)) {
403		SM2error(ERR_R_EVP_LIB);
404		goto err;
405	}
406
407	if (!EVP_DigestUpdate(hash, msg, msg_len)) {
408		SM2error(ERR_R_EVP_LIB);
409		goto err;
410	}
411
412	if (!EVP_DigestUpdate(hash, x2y2 + field_size, field_size)) {
413		SM2error(ERR_R_EVP_LIB);
414		goto err;
415	}
416
417	if (!EVP_DigestFinal(hash, C3, NULL)) {
418		SM2error(ERR_R_EVP_LIB);
419		goto err;
420	}
421
422	ctext_struct.C1x = x1;
423	ctext_struct.C1y = y1;
424	if ((ctext_struct.C3 = ASN1_OCTET_STRING_new()) == NULL) {
425		SM2error(ERR_R_MALLOC_FAILURE);
426		goto err;
427	}
428	if ((ctext_struct.C2 = ASN1_OCTET_STRING_new()) == NULL) {
429		SM2error(ERR_R_MALLOC_FAILURE);
430		goto err;
431	}
432	if (!ASN1_OCTET_STRING_set(ctext_struct.C3, C3, C3_size)) {
433		SM2error(ERR_R_INTERNAL_ERROR);
434		goto err;
435	}
436	if (!ASN1_OCTET_STRING_set(ctext_struct.C2, msg_mask, msg_len)) {
437		SM2error(ERR_R_INTERNAL_ERROR);
438		goto err;
439	}
440
441	if ((clen = i2d_SM2_Ciphertext(&ctext_struct, &ciphertext_buf)) < 0) {
442		SM2error(ERR_R_INTERNAL_ERROR);
443		goto err;
444	}
445
446	*ciphertext_len = clen;
447	rc = 1;
448
449 err:
450	ASN1_OCTET_STRING_free(ctext_struct.C2);
451	ASN1_OCTET_STRING_free(ctext_struct.C3);
452	free(msg_mask);
453	free(x2y2);
454	free(C3);
455	EVP_MD_CTX_free(hash);
456	BN_CTX_end(ctx);
457	BN_CTX_free(ctx);
458	EC_POINT_free(kG);
459	EC_POINT_free(kP);
460	BN_free(order);
461	return rc;
462}
463
464int
465SM2_decrypt(const EC_KEY *key, const EVP_MD *digest, const uint8_t *ciphertext,
466    size_t ciphertext_len, uint8_t *ptext_buf, size_t *ptext_len)
467{
468	SM2_Ciphertext *sm2_ctext = NULL;
469	EVP_MD_CTX *hash = NULL;
470	BN_CTX *ctx = NULL;
471	BIGNUM *x2, *y2;
472	const EC_GROUP *group;
473	EC_POINT *C1 = NULL;
474	const uint8_t *C2, *C3;
475	uint8_t *computed_C3 = NULL, *msg_mask = NULL, *x2y2 = NULL;
476	size_t field_size, x2size, y2size;
477	int msg_len = 0, rc = 0;
478	int hash_size, i;
479
480	if ((group = EC_KEY_get0_group(key)) == NULL) {
481		SM2error(SM2_R_INVALID_KEY);
482		goto err;
483	}
484
485	if ((field_size = ec_field_size(group)) == 0) {
486		SM2error(SM2_R_INVALID_FIELD);
487		goto err;
488	}
489
490	if ((hash_size = EVP_MD_size(digest)) < 0) {
491		SM2error(SM2_R_INVALID_DIGEST);
492		goto err;
493	}
494
495	memset(ptext_buf, 0xFF, *ptext_len);
496
497	if ((sm2_ctext = d2i_SM2_Ciphertext(NULL, &ciphertext,
498	    ciphertext_len)) == NULL) {
499		SM2error(SM2_R_ASN1_ERROR);
500		goto err;
501	}
502
503	if (sm2_ctext->C3->length != hash_size) {
504		SM2error(SM2_R_INVALID_ENCODING);
505		goto err;
506	}
507
508	C2 = sm2_ctext->C2->data;
509	C3 = sm2_ctext->C3->data;
510	msg_len = sm2_ctext->C2->length;
511
512	if ((ctx = BN_CTX_new()) == NULL) {
513		SM2error(ERR_R_MALLOC_FAILURE);
514		goto err;
515	}
516
517	BN_CTX_start(ctx);
518	if ((x2 = BN_CTX_get(ctx)) == NULL) {
519		SM2error(ERR_R_BN_LIB);
520		goto err;
521	}
522	if ((y2 = BN_CTX_get(ctx)) == NULL) {
523		SM2error(ERR_R_BN_LIB);
524		goto err;
525	}
526
527	if ((msg_mask = calloc(1, msg_len)) == NULL) {
528		SM2error(ERR_R_MALLOC_FAILURE);
529		goto err;
530	}
531	if ((x2y2 = calloc(2, field_size)) == NULL) {
532		SM2error(ERR_R_MALLOC_FAILURE);
533		goto err;
534	}
535	if ((computed_C3 = calloc(1, hash_size)) == NULL) {
536		SM2error(ERR_R_MALLOC_FAILURE);
537		goto err;
538	}
539
540	if ((C1 = EC_POINT_new(group)) == NULL) {
541		SM2error(ERR_R_MALLOC_FAILURE);
542		goto err;
543	}
544
545	if (!EC_POINT_set_affine_coordinates(group, C1, sm2_ctext->C1x,
546	    sm2_ctext->C1y, ctx))
547	{
548		SM2error(ERR_R_EC_LIB);
549		goto err;
550	}
551
552	if (!EC_POINT_mul(group, C1, NULL, C1, EC_KEY_get0_private_key(key),
553	    ctx)) {
554		SM2error(ERR_R_EC_LIB);
555		goto err;
556	}
557
558	if (!EC_POINT_get_affine_coordinates(group, C1, x2, y2, ctx)) {
559		SM2error(ERR_R_EC_LIB);
560		goto err;
561	}
562
563	if ((x2size = BN_num_bytes(x2)) > field_size ||
564	    (y2size = BN_num_bytes(y2)) > field_size) {
565		SM2error(SM2_R_BIGNUM_OUT_OF_RANGE);
566		goto err;
567	}
568
569	BN_bn2bin(x2, x2y2 + field_size - x2size);
570	BN_bn2bin(y2, x2y2 + 2 * field_size - y2size);
571
572	if (!sm2_kdf(msg_mask, msg_len, x2y2, 2 * field_size, digest)) {
573		SM2error(SM2_R_KDF_FAILURE);
574		goto err;
575	}
576
577	for (i = 0; i != msg_len; ++i)
578		ptext_buf[i] = C2[i] ^ msg_mask[i];
579
580	if ((hash = EVP_MD_CTX_new()) == NULL) {
581		SM2error(ERR_R_EVP_LIB);
582		goto err;
583	}
584
585	if (!EVP_DigestInit(hash, digest)) {
586		SM2error(ERR_R_EVP_LIB);
587		goto err;
588	}
589
590	if (!EVP_DigestUpdate(hash, x2y2, field_size)) {
591		SM2error(ERR_R_EVP_LIB);
592		goto err;
593	}
594
595	if (!EVP_DigestUpdate(hash, ptext_buf, msg_len)) {
596		SM2error(ERR_R_EVP_LIB);
597		goto err;
598	}
599
600	if (!EVP_DigestUpdate(hash, x2y2 + field_size, field_size)) {
601		SM2error(ERR_R_EVP_LIB);
602		goto err;
603	}
604
605	if (!EVP_DigestFinal(hash, computed_C3, NULL)) {
606		SM2error(ERR_R_EVP_LIB);
607		goto err;
608	}
609
610	if (memcmp(computed_C3, C3, hash_size) != 0)
611		goto err;
612
613	rc = 1;
614	*ptext_len = msg_len;
615
616 err:
617	if (rc == 0)
618		memset(ptext_buf, 0, *ptext_len);
619
620	free(msg_mask);
621	free(x2y2);
622	free(computed_C3);
623	EC_POINT_free(C1);
624	BN_CTX_end(ctx);
625	BN_CTX_free(ctx);
626	SM2_Ciphertext_free(sm2_ctext);
627	EVP_MD_CTX_free(hash);
628
629	return rc;
630}
631
632#endif /* OPENSSL_NO_SM2 */
633