1/*
2 * Copyright 2017-2023 The OpenSSL Project Authors. All Rights Reserved.
3 * Copyright 2017 Ribose Inc. All Rights Reserved.
4 * Ported from Ribose contributions from Botan.
5 *
6 * Licensed under the Apache License 2.0 (the "License").  You may not use
7 * this file except in compliance with the License.  You can obtain a copy
8 * in the file LICENSE in the source distribution or at
9 * https://www.openssl.org/source/license.html
10 */
11
12#include "internal/deprecated.h"
13
14#include "crypto/sm2.h"
15#include "crypto/sm2err.h"
16#include "crypto/ec.h" /* ossl_ec_group_do_inverse_ord() */
17#include "internal/numbers.h"
18#include <openssl/err.h>
19#include <openssl/evp.h>
20#include <openssl/err.h>
21#include <openssl/bn.h>
22#include <string.h>
23
24int ossl_sm2_compute_z_digest(uint8_t *out,
25                              const EVP_MD *digest,
26                              const uint8_t *id,
27                              const size_t id_len,
28                              const EC_KEY *key)
29{
30    int rc = 0;
31    const EC_GROUP *group = EC_KEY_get0_group(key);
32    BN_CTX *ctx = NULL;
33    EVP_MD_CTX *hash = NULL;
34    BIGNUM *p = NULL;
35    BIGNUM *a = NULL;
36    BIGNUM *b = NULL;
37    BIGNUM *xG = NULL;
38    BIGNUM *yG = NULL;
39    BIGNUM *xA = NULL;
40    BIGNUM *yA = NULL;
41    int p_bytes = 0;
42    uint8_t *buf = NULL;
43    uint16_t entl = 0;
44    uint8_t e_byte = 0;
45
46    hash = EVP_MD_CTX_new();
47    ctx = BN_CTX_new_ex(ossl_ec_key_get_libctx(key));
48    if (hash == NULL || ctx == NULL) {
49        ERR_raise(ERR_LIB_SM2, ERR_R_MALLOC_FAILURE);
50        goto done;
51    }
52
53    p = BN_CTX_get(ctx);
54    a = BN_CTX_get(ctx);
55    b = BN_CTX_get(ctx);
56    xG = BN_CTX_get(ctx);
57    yG = BN_CTX_get(ctx);
58    xA = BN_CTX_get(ctx);
59    yA = BN_CTX_get(ctx);
60
61    if (yA == NULL) {
62        ERR_raise(ERR_LIB_SM2, ERR_R_MALLOC_FAILURE);
63        goto done;
64    }
65
66    if (!EVP_DigestInit(hash, digest)) {
67        ERR_raise(ERR_LIB_SM2, ERR_R_EVP_LIB);
68        goto done;
69    }
70
71    /* Z = h(ENTL || ID || a || b || xG || yG || xA || yA) */
72
73    if (id_len >= (UINT16_MAX / 8)) {
74        /* too large */
75        ERR_raise(ERR_LIB_SM2, SM2_R_ID_TOO_LARGE);
76        goto done;
77    }
78
79    entl = (uint16_t)(8 * id_len);
80
81    e_byte = entl >> 8;
82    if (!EVP_DigestUpdate(hash, &e_byte, 1)) {
83        ERR_raise(ERR_LIB_SM2, ERR_R_EVP_LIB);
84        goto done;
85    }
86    e_byte = entl & 0xFF;
87    if (!EVP_DigestUpdate(hash, &e_byte, 1)) {
88        ERR_raise(ERR_LIB_SM2, ERR_R_EVP_LIB);
89        goto done;
90    }
91
92    if (id_len > 0 && !EVP_DigestUpdate(hash, id, id_len)) {
93        ERR_raise(ERR_LIB_SM2, ERR_R_EVP_LIB);
94        goto done;
95    }
96
97    if (!EC_GROUP_get_curve(group, p, a, b, ctx)) {
98        ERR_raise(ERR_LIB_SM2, ERR_R_EC_LIB);
99        goto done;
100    }
101
102    p_bytes = BN_num_bytes(p);
103    buf = OPENSSL_zalloc(p_bytes);
104    if (buf == NULL) {
105        ERR_raise(ERR_LIB_SM2, ERR_R_MALLOC_FAILURE);
106        goto done;
107    }
108
109    if (BN_bn2binpad(a, buf, p_bytes) < 0
110            || !EVP_DigestUpdate(hash, buf, p_bytes)
111            || BN_bn2binpad(b, buf, p_bytes) < 0
112            || !EVP_DigestUpdate(hash, buf, p_bytes)
113            || !EC_POINT_get_affine_coordinates(group,
114                                                EC_GROUP_get0_generator(group),
115                                                xG, yG, ctx)
116            || BN_bn2binpad(xG, buf, p_bytes) < 0
117            || !EVP_DigestUpdate(hash, buf, p_bytes)
118            || BN_bn2binpad(yG, buf, p_bytes) < 0
119            || !EVP_DigestUpdate(hash, buf, p_bytes)
120            || !EC_POINT_get_affine_coordinates(group,
121                                                EC_KEY_get0_public_key(key),
122                                                xA, yA, ctx)
123            || BN_bn2binpad(xA, buf, p_bytes) < 0
124            || !EVP_DigestUpdate(hash, buf, p_bytes)
125            || BN_bn2binpad(yA, buf, p_bytes) < 0
126            || !EVP_DigestUpdate(hash, buf, p_bytes)
127            || !EVP_DigestFinal(hash, out, NULL)) {
128        ERR_raise(ERR_LIB_SM2, ERR_R_INTERNAL_ERROR);
129        goto done;
130    }
131
132    rc = 1;
133
134 done:
135    OPENSSL_free(buf);
136    BN_CTX_free(ctx);
137    EVP_MD_CTX_free(hash);
138    return rc;
139}
140
141static BIGNUM *sm2_compute_msg_hash(const EVP_MD *digest,
142                                    const EC_KEY *key,
143                                    const uint8_t *id,
144                                    const size_t id_len,
145                                    const uint8_t *msg, size_t msg_len)
146{
147    EVP_MD_CTX *hash = EVP_MD_CTX_new();
148    const int md_size = EVP_MD_get_size(digest);
149    uint8_t *z = NULL;
150    BIGNUM *e = NULL;
151    EVP_MD *fetched_digest = NULL;
152    OSSL_LIB_CTX *libctx = ossl_ec_key_get_libctx(key);
153    const char *propq = ossl_ec_key_get0_propq(key);
154
155    if (md_size < 0) {
156        ERR_raise(ERR_LIB_SM2, SM2_R_INVALID_DIGEST);
157        goto done;
158    }
159
160    z = OPENSSL_zalloc(md_size);
161    if (hash == NULL || z == NULL) {
162        ERR_raise(ERR_LIB_SM2, ERR_R_MALLOC_FAILURE);
163        goto done;
164    }
165
166    fetched_digest = EVP_MD_fetch(libctx, EVP_MD_get0_name(digest), propq);
167    if (fetched_digest == NULL) {
168        ERR_raise(ERR_LIB_SM2, ERR_R_INTERNAL_ERROR);
169        goto done;
170    }
171
172    if (!ossl_sm2_compute_z_digest(z, fetched_digest, id, id_len, key)) {
173        /* SM2err already called */
174        goto done;
175    }
176
177    if (!EVP_DigestInit(hash, fetched_digest)
178            || !EVP_DigestUpdate(hash, z, md_size)
179            || !EVP_DigestUpdate(hash, msg, msg_len)
180               /* reuse z buffer to hold H(Z || M) */
181            || !EVP_DigestFinal(hash, z, NULL)) {
182        ERR_raise(ERR_LIB_SM2, ERR_R_EVP_LIB);
183        goto done;
184    }
185
186    e = BN_bin2bn(z, md_size, NULL);
187    if (e == NULL)
188        ERR_raise(ERR_LIB_SM2, ERR_R_INTERNAL_ERROR);
189
190 done:
191    EVP_MD_free(fetched_digest);
192    OPENSSL_free(z);
193    EVP_MD_CTX_free(hash);
194    return e;
195}
196
197static ECDSA_SIG *sm2_sig_gen(const EC_KEY *key, const BIGNUM *e)
198{
199    const BIGNUM *dA = EC_KEY_get0_private_key(key);
200    const EC_GROUP *group = EC_KEY_get0_group(key);
201    const BIGNUM *order = EC_GROUP_get0_order(group);
202    ECDSA_SIG *sig = NULL;
203    EC_POINT *kG = NULL;
204    BN_CTX *ctx = NULL;
205    BIGNUM *k = NULL;
206    BIGNUM *rk = NULL;
207    BIGNUM *r = NULL;
208    BIGNUM *s = NULL;
209    BIGNUM *x1 = NULL;
210    BIGNUM *tmp = NULL;
211    OSSL_LIB_CTX *libctx = ossl_ec_key_get_libctx(key);
212
213    kG = EC_POINT_new(group);
214    ctx = BN_CTX_new_ex(libctx);
215    if (kG == NULL || ctx == NULL) {
216        ERR_raise(ERR_LIB_SM2, ERR_R_MALLOC_FAILURE);
217        goto done;
218    }
219
220    BN_CTX_start(ctx);
221    k = BN_CTX_get(ctx);
222    rk = BN_CTX_get(ctx);
223    x1 = BN_CTX_get(ctx);
224    tmp = BN_CTX_get(ctx);
225    if (tmp == NULL) {
226        ERR_raise(ERR_LIB_SM2, ERR_R_MALLOC_FAILURE);
227        goto done;
228    }
229
230    /*
231     * These values are returned and so should not be allocated out of the
232     * context
233     */
234    r = BN_new();
235    s = BN_new();
236
237    if (r == NULL || s == NULL) {
238        ERR_raise(ERR_LIB_SM2, ERR_R_MALLOC_FAILURE);
239        goto done;
240    }
241
242    /*
243     * A3: Generate a random number k in [1,n-1] using random number generators;
244     * A4: Compute (x1,y1)=[k]G, and convert the type of data x1 to be integer
245     *     as specified in clause 4.2.8 of GM/T 0003.1-2012;
246     * A5: Compute r=(e+x1) mod n. If r=0 or r+k=n, then go to A3;
247     * A6: Compute s=(1/(1+dA)*(k-r*dA)) mod n. If s=0, then go to A3;
248     * A7: Convert the type of data (r,s) to be bit strings according to the details
249     *     in clause 4.2.2 of GM/T 0003.1-2012. Then the signature of message M is (r,s).
250     */
251    for (;;) {
252        if (!BN_priv_rand_range_ex(k, order, 0, ctx)) {
253            ERR_raise(ERR_LIB_SM2, ERR_R_INTERNAL_ERROR);
254            goto done;
255        }
256
257        if (!EC_POINT_mul(group, kG, k, NULL, NULL, ctx)
258                || !EC_POINT_get_affine_coordinates(group, kG, x1, NULL,
259                                                    ctx)
260                || !BN_mod_add(r, e, x1, order, ctx)) {
261            ERR_raise(ERR_LIB_SM2, ERR_R_INTERNAL_ERROR);
262            goto done;
263        }
264
265        /* try again if r == 0 or r+k == n */
266        if (BN_is_zero(r))
267            continue;
268
269        if (!BN_add(rk, r, k)) {
270            ERR_raise(ERR_LIB_SM2, ERR_R_INTERNAL_ERROR);
271            goto done;
272        }
273
274        if (BN_cmp(rk, order) == 0)
275            continue;
276
277        if (!BN_add(s, dA, BN_value_one())
278                || !ossl_ec_group_do_inverse_ord(group, s, s, ctx)
279                || !BN_mod_mul(tmp, dA, r, order, ctx)
280                || !BN_sub(tmp, k, tmp)
281                || !BN_mod_mul(s, s, tmp, order, ctx)) {
282            ERR_raise(ERR_LIB_SM2, ERR_R_BN_LIB);
283            goto done;
284        }
285
286        /* try again if s == 0 */
287        if (BN_is_zero(s))
288            continue;
289
290        sig = ECDSA_SIG_new();
291        if (sig == NULL) {
292            ERR_raise(ERR_LIB_SM2, ERR_R_MALLOC_FAILURE);
293            goto done;
294        }
295
296         /* takes ownership of r and s */
297        ECDSA_SIG_set0(sig, r, s);
298        break;
299    }
300
301 done:
302    if (sig == NULL) {
303        BN_free(r);
304        BN_free(s);
305    }
306
307    BN_CTX_free(ctx);
308    EC_POINT_free(kG);
309    return sig;
310}
311
312static int sm2_sig_verify(const EC_KEY *key, const ECDSA_SIG *sig,
313                          const BIGNUM *e)
314{
315    int ret = 0;
316    const EC_GROUP *group = EC_KEY_get0_group(key);
317    const BIGNUM *order = EC_GROUP_get0_order(group);
318    BN_CTX *ctx = NULL;
319    EC_POINT *pt = NULL;
320    BIGNUM *t = NULL;
321    BIGNUM *x1 = NULL;
322    const BIGNUM *r = NULL;
323    const BIGNUM *s = NULL;
324    OSSL_LIB_CTX *libctx = ossl_ec_key_get_libctx(key);
325
326    ctx = BN_CTX_new_ex(libctx);
327    pt = EC_POINT_new(group);
328    if (ctx == NULL || pt == NULL) {
329        ERR_raise(ERR_LIB_SM2, ERR_R_MALLOC_FAILURE);
330        goto done;
331    }
332
333    BN_CTX_start(ctx);
334    t = BN_CTX_get(ctx);
335    x1 = BN_CTX_get(ctx);
336    if (x1 == NULL) {
337        ERR_raise(ERR_LIB_SM2, ERR_R_MALLOC_FAILURE);
338        goto done;
339    }
340
341    /*
342     * B1: verify whether r' in [1,n-1], verification failed if not
343     * B2: verify whether s' in [1,n-1], verification failed if not
344     * B3: set M'~=ZA || M'
345     * B4: calculate e'=Hv(M'~)
346     * B5: calculate t = (r' + s') modn, verification failed if t=0
347     * B6: calculate the point (x1', y1')=[s']G + [t]PA
348     * B7: calculate R=(e'+x1') modn, verification pass if yes, otherwise failed
349     */
350
351    ECDSA_SIG_get0(sig, &r, &s);
352
353    if (BN_cmp(r, BN_value_one()) < 0
354            || BN_cmp(s, BN_value_one()) < 0
355            || BN_cmp(order, r) <= 0
356            || BN_cmp(order, s) <= 0) {
357        ERR_raise(ERR_LIB_SM2, SM2_R_BAD_SIGNATURE);
358        goto done;
359    }
360
361    if (!BN_mod_add(t, r, s, order, ctx)) {
362        ERR_raise(ERR_LIB_SM2, ERR_R_BN_LIB);
363        goto done;
364    }
365
366    if (BN_is_zero(t)) {
367        ERR_raise(ERR_LIB_SM2, SM2_R_BAD_SIGNATURE);
368        goto done;
369    }
370
371    if (!EC_POINT_mul(group, pt, s, EC_KEY_get0_public_key(key), t, ctx)
372            || !EC_POINT_get_affine_coordinates(group, pt, x1, NULL, ctx)) {
373        ERR_raise(ERR_LIB_SM2, ERR_R_EC_LIB);
374        goto done;
375    }
376
377    if (!BN_mod_add(t, e, x1, order, ctx)) {
378        ERR_raise(ERR_LIB_SM2, ERR_R_BN_LIB);
379        goto done;
380    }
381
382    if (BN_cmp(r, t) == 0)
383        ret = 1;
384
385 done:
386    EC_POINT_free(pt);
387    BN_CTX_free(ctx);
388    return ret;
389}
390
391ECDSA_SIG *ossl_sm2_do_sign(const EC_KEY *key,
392                            const EVP_MD *digest,
393                            const uint8_t *id,
394                            const size_t id_len,
395                            const uint8_t *msg, size_t msg_len)
396{
397    BIGNUM *e = NULL;
398    ECDSA_SIG *sig = NULL;
399
400    e = sm2_compute_msg_hash(digest, key, id, id_len, msg, msg_len);
401    if (e == NULL) {
402        /* SM2err already called */
403        goto done;
404    }
405
406    sig = sm2_sig_gen(key, e);
407
408 done:
409    BN_free(e);
410    return sig;
411}
412
413int ossl_sm2_do_verify(const EC_KEY *key,
414                       const EVP_MD *digest,
415                       const ECDSA_SIG *sig,
416                       const uint8_t *id,
417                       const size_t id_len,
418                       const uint8_t *msg, size_t msg_len)
419{
420    BIGNUM *e = NULL;
421    int ret = 0;
422
423    e = sm2_compute_msg_hash(digest, key, id, id_len, msg, msg_len);
424    if (e == NULL) {
425        /* SM2err already called */
426        goto done;
427    }
428
429    ret = sm2_sig_verify(key, sig, e);
430
431 done:
432    BN_free(e);
433    return ret;
434}
435
436int ossl_sm2_internal_sign(const unsigned char *dgst, int dgstlen,
437                           unsigned char *sig, unsigned int *siglen,
438                           EC_KEY *eckey)
439{
440    BIGNUM *e = NULL;
441    ECDSA_SIG *s = NULL;
442    int sigleni;
443    int ret = -1;
444
445    e = BN_bin2bn(dgst, dgstlen, NULL);
446    if (e == NULL) {
447       ERR_raise(ERR_LIB_SM2, ERR_R_BN_LIB);
448       goto done;
449    }
450
451    s = sm2_sig_gen(eckey, e);
452    if (s == NULL) {
453        ERR_raise(ERR_LIB_SM2, ERR_R_INTERNAL_ERROR);
454        goto done;
455    }
456
457    sigleni = i2d_ECDSA_SIG(s, sig != NULL ? &sig : NULL);
458    if (sigleni < 0) {
459       ERR_raise(ERR_LIB_SM2, ERR_R_INTERNAL_ERROR);
460       goto done;
461    }
462    *siglen = (unsigned int)sigleni;
463
464    ret = 1;
465
466 done:
467    ECDSA_SIG_free(s);
468    BN_free(e);
469    return ret;
470}
471
472int ossl_sm2_internal_verify(const unsigned char *dgst, int dgstlen,
473                             const unsigned char *sig, int sig_len,
474                             EC_KEY *eckey)
475{
476    ECDSA_SIG *s = NULL;
477    BIGNUM *e = NULL;
478    const unsigned char *p = sig;
479    unsigned char *der = NULL;
480    int derlen = -1;
481    int ret = -1;
482
483    s = ECDSA_SIG_new();
484    if (s == NULL) {
485        ERR_raise(ERR_LIB_SM2, ERR_R_MALLOC_FAILURE);
486        goto done;
487    }
488    if (d2i_ECDSA_SIG(&s, &p, sig_len) == NULL) {
489        ERR_raise(ERR_LIB_SM2, SM2_R_INVALID_ENCODING);
490        goto done;
491    }
492    /* Ensure signature uses DER and doesn't have trailing garbage */
493    derlen = i2d_ECDSA_SIG(s, &der);
494    if (derlen != sig_len || memcmp(sig, der, derlen) != 0) {
495        ERR_raise(ERR_LIB_SM2, SM2_R_INVALID_ENCODING);
496        goto done;
497    }
498
499    e = BN_bin2bn(dgst, dgstlen, NULL);
500    if (e == NULL) {
501        ERR_raise(ERR_LIB_SM2, ERR_R_BN_LIB);
502        goto done;
503    }
504
505    ret = sm2_sig_verify(eckey, s, e);
506
507 done:
508    OPENSSL_free(der);
509    BN_free(e);
510    ECDSA_SIG_free(s);
511    return ret;
512}
513