1/*
2 * Copyright 2018-2024 The OpenSSL Project Authors. All Rights Reserved.
3 * Copyright (c) 2018-2019, Oracle and/or its affiliates.  All rights reserved.
4 *
5 * Licensed under the Apache License 2.0 (the "License").  You may not use
6 * this file except in compliance with the License.  You can obtain a copy
7 * in the file LICENSE in the source distribution or at
8 * https://www.openssl.org/source/license.html
9 */
10
11#include <openssl/err.h>
12#include <openssl/bn.h>
13#include "crypto/bn.h"
14#include "rsa_local.h"
15
16/*
17 * Part of the RSA keypair test.
18 * Check the Chinese Remainder Theorem components are valid.
19 *
20 * See SP800-5bBr1
21 *   6.4.1.2.3: rsakpv1-crt Step 7
22 *   6.4.1.3.3: rsakpv2-crt Step 7
23 */
24int ossl_rsa_check_crt_components(const RSA *rsa, BN_CTX *ctx)
25{
26    int ret = 0;
27    BIGNUM *r = NULL, *p1 = NULL, *q1 = NULL;
28
29    /* check if only some of the crt components are set */
30    if (rsa->dmp1 == NULL || rsa->dmq1 == NULL || rsa->iqmp == NULL) {
31        if (rsa->dmp1 != NULL || rsa->dmq1 != NULL || rsa->iqmp != NULL)
32            return 0;
33        return 1; /* return ok if all components are NULL */
34    }
35
36    BN_CTX_start(ctx);
37    r = BN_CTX_get(ctx);
38    p1 = BN_CTX_get(ctx);
39    q1 = BN_CTX_get(ctx);
40    if (q1 != NULL) {
41        BN_set_flags(r, BN_FLG_CONSTTIME);
42        BN_set_flags(p1, BN_FLG_CONSTTIME);
43        BN_set_flags(q1, BN_FLG_CONSTTIME);
44        ret = 1;
45    } else {
46        ret = 0;
47    }
48    ret = ret
49          /* p1 = p -1 */
50          && (BN_copy(p1, rsa->p) != NULL)
51          && BN_sub_word(p1, 1)
52          /* q1 = q - 1 */
53          && (BN_copy(q1, rsa->q) != NULL)
54          && BN_sub_word(q1, 1)
55          /* (a) 1 < dP < (p ��� 1). */
56          && (BN_cmp(rsa->dmp1, BN_value_one()) > 0)
57          && (BN_cmp(rsa->dmp1, p1) < 0)
58          /* (b) 1 < dQ < (q - 1). */
59          && (BN_cmp(rsa->dmq1, BN_value_one()) > 0)
60          && (BN_cmp(rsa->dmq1, q1) < 0)
61          /* (c) 1 < qInv < p */
62          && (BN_cmp(rsa->iqmp, BN_value_one()) > 0)
63          && (BN_cmp(rsa->iqmp, rsa->p) < 0)
64          /* (d) 1 = (dP . e) mod (p - 1)*/
65          && BN_mod_mul(r, rsa->dmp1, rsa->e, p1, ctx)
66          && BN_is_one(r)
67          /* (e) 1 = (dQ . e) mod (q - 1) */
68          && BN_mod_mul(r, rsa->dmq1, rsa->e, q1, ctx)
69          && BN_is_one(r)
70          /* (f) 1 = (qInv . q) mod p */
71          && BN_mod_mul(r, rsa->iqmp, rsa->q, rsa->p, ctx)
72          && BN_is_one(r);
73    BN_clear(r);
74    BN_clear(p1);
75    BN_clear(q1);
76    BN_CTX_end(ctx);
77    return ret;
78}
79
80/*
81 * Part of the RSA keypair test.
82 * Check that (���2)(2^(nbits/2 - 1) <= p <= 2^(nbits/2) - 1
83 *
84 * See SP800-5bBr1 6.4.1.2.1 Part 5 (c) & (g) - used for both p and q.
85 *
86 * (���2)(2^(nbits/2 - 1) = (���2/2)(2^(nbits/2))
87 */
88int ossl_rsa_check_prime_factor_range(const BIGNUM *p, int nbits, BN_CTX *ctx)
89{
90    int ret = 0;
91    BIGNUM *low;
92    int shift;
93
94    nbits >>= 1;
95    shift = nbits - BN_num_bits(&ossl_bn_inv_sqrt_2);
96
97    /* Upper bound check */
98    if (BN_num_bits(p) != nbits)
99        return 0;
100
101    BN_CTX_start(ctx);
102    low = BN_CTX_get(ctx);
103    if (low == NULL)
104        goto err;
105
106    /* set low = (���2)(2^(nbits/2 - 1) */
107    if (!BN_copy(low, &ossl_bn_inv_sqrt_2))
108        goto err;
109
110    if (shift >= 0) {
111        /*
112         * We don't have all the bits. ossl_bn_inv_sqrt_2 contains a rounded up
113         * value, so there is a very low probability that we'll reject a valid
114         * value.
115         */
116        if (!BN_lshift(low, low, shift))
117            goto err;
118    } else if (!BN_rshift(low, low, -shift)) {
119        goto err;
120    }
121    if (BN_cmp(p, low) <= 0)
122        goto err;
123    ret = 1;
124err:
125    BN_CTX_end(ctx);
126    return ret;
127}
128
129/*
130 * Part of the RSA keypair test.
131 * Check the prime factor (for either p or q)
132 * i.e: p is prime AND GCD(p - 1, e) = 1
133 *
134 * See SP800-56Br1 6.4.1.2.3 Step 5 (a to d) & (e to h).
135 */
136int ossl_rsa_check_prime_factor(BIGNUM *p, BIGNUM *e, int nbits, BN_CTX *ctx)
137{
138    int ret = 0;
139    BIGNUM *p1 = NULL, *gcd = NULL;
140
141    /* (Steps 5 a-b) prime test */
142    if (BN_check_prime(p, ctx, NULL) != 1
143            /* (Step 5c) (���2)(2^(nbits/2 - 1) <= p <= 2^(nbits/2 - 1) */
144            || ossl_rsa_check_prime_factor_range(p, nbits, ctx) != 1)
145        return 0;
146
147    BN_CTX_start(ctx);
148    p1 = BN_CTX_get(ctx);
149    gcd = BN_CTX_get(ctx);
150    if (gcd != NULL) {
151        BN_set_flags(p1, BN_FLG_CONSTTIME);
152        BN_set_flags(gcd, BN_FLG_CONSTTIME);
153        ret = 1;
154    } else {
155        ret = 0;
156    }
157    ret = ret
158          /* (Step 5d) GCD(p-1, e) = 1 */
159          && (BN_copy(p1, p) != NULL)
160          && BN_sub_word(p1, 1)
161          && BN_gcd(gcd, p1, e, ctx)
162          && BN_is_one(gcd);
163
164    BN_clear(p1);
165    BN_CTX_end(ctx);
166    return ret;
167}
168
169/*
170 * See SP800-56Br1 6.4.1.2.3 Part 6(a-b) Check the private exponent d
171 * satisfies:
172 *     (Step 6a) 2^(nBit/2) < d < LCM(p���1, q���1).
173 *     (Step 6b) 1 = (d*e) mod LCM(p���1, q���1)
174 */
175int ossl_rsa_check_private_exponent(const RSA *rsa, int nbits, BN_CTX *ctx)
176{
177    int ret;
178    BIGNUM *r, *p1, *q1, *lcm, *p1q1, *gcd;
179
180    /* (Step 6a) 2^(nbits/2) < d */
181    if (BN_num_bits(rsa->d) <= (nbits >> 1))
182        return 0;
183
184    BN_CTX_start(ctx);
185    r = BN_CTX_get(ctx);
186    p1 = BN_CTX_get(ctx);
187    q1 = BN_CTX_get(ctx);
188    lcm = BN_CTX_get(ctx);
189    p1q1 = BN_CTX_get(ctx);
190    gcd = BN_CTX_get(ctx);
191    if (gcd != NULL) {
192        BN_set_flags(r, BN_FLG_CONSTTIME);
193        BN_set_flags(p1, BN_FLG_CONSTTIME);
194        BN_set_flags(q1, BN_FLG_CONSTTIME);
195        BN_set_flags(lcm, BN_FLG_CONSTTIME);
196        BN_set_flags(p1q1, BN_FLG_CONSTTIME);
197        BN_set_flags(gcd, BN_FLG_CONSTTIME);
198        ret = 1;
199    } else {
200        ret = 0;
201    }
202    ret = (ret
203          /* LCM(p - 1, q - 1) */
204          && (ossl_rsa_get_lcm(ctx, rsa->p, rsa->q, lcm, gcd, p1, q1,
205                               p1q1) == 1)
206          /* (Step 6a) d < LCM(p - 1, q - 1) */
207          && (BN_cmp(rsa->d, lcm) < 0)
208          /* (Step 6b) 1 = (e . d) mod LCM(p - 1, q - 1) */
209          && BN_mod_mul(r, rsa->e, rsa->d, lcm, ctx)
210          && BN_is_one(r));
211
212    BN_clear(r);
213    BN_clear(p1);
214    BN_clear(q1);
215    BN_clear(lcm);
216    BN_clear(gcd);
217    BN_CTX_end(ctx);
218    return ret;
219}
220
221/*
222 * Check exponent is odd.
223 * For FIPS also check the bit length is in the range [17..256]
224 */
225int ossl_rsa_check_public_exponent(const BIGNUM *e)
226{
227#ifdef FIPS_MODULE
228    int bitlen;
229
230    bitlen = BN_num_bits(e);
231    return (BN_is_odd(e) && bitlen > 16 && bitlen < 257);
232#else
233    /* Allow small exponents larger than 1 for legacy purposes */
234    return BN_is_odd(e) && BN_cmp(e, BN_value_one()) > 0;
235#endif /* FIPS_MODULE */
236}
237
238/*
239 * SP800-56Br1 6.4.1.2.1 (Step 5i): |p - q| > 2^(nbits/2 - 100)
240 * i.e- numbits(p-q-1) > (nbits/2 -100)
241 */
242int ossl_rsa_check_pminusq_diff(BIGNUM *diff, const BIGNUM *p, const BIGNUM *q,
243                           int nbits)
244{
245    int bitlen = (nbits >> 1) - 100;
246
247    if (!BN_sub(diff, p, q))
248        return -1;
249    BN_set_negative(diff, 0);
250
251    if (BN_is_zero(diff))
252        return 0;
253
254    if (!BN_sub_word(diff, 1))
255        return -1;
256    return (BN_num_bits(diff) > bitlen);
257}
258
259/*
260 * return LCM(p-1, q-1)
261 *
262 * Caller should ensure that lcm, gcd, p1, q1, p1q1 are flagged with
263 * BN_FLG_CONSTTIME.
264 */
265int ossl_rsa_get_lcm(BN_CTX *ctx, const BIGNUM *p, const BIGNUM *q,
266                     BIGNUM *lcm, BIGNUM *gcd, BIGNUM *p1, BIGNUM *q1,
267                     BIGNUM *p1q1)
268{
269    return BN_sub(p1, p, BN_value_one())    /* p-1 */
270           && BN_sub(q1, q, BN_value_one()) /* q-1 */
271           && BN_mul(p1q1, p1, q1, ctx)     /* (p-1)(q-1) */
272           && BN_gcd(gcd, p1, q1, ctx)
273           && BN_div(lcm, NULL, p1q1, gcd, ctx); /* LCM((p-1, q-1)) */
274}
275
276/*
277 * SP800-56Br1 6.4.2.2 Partial Public Key Validation for RSA refers to
278 * SP800-89 5.3.3 (Explicit) Partial Public Key Validation for RSA
279 * caveat is that the modulus must be as specified in SP800-56Br1
280 */
281int ossl_rsa_sp800_56b_check_public(const RSA *rsa)
282{
283    int ret = 0, status;
284    int nbits;
285    BN_CTX *ctx = NULL;
286    BIGNUM *gcd = NULL;
287
288    if (rsa->n == NULL || rsa->e == NULL)
289        return 0;
290
291    nbits = BN_num_bits(rsa->n);
292    if (nbits > OPENSSL_RSA_MAX_MODULUS_BITS) {
293        ERR_raise(ERR_LIB_RSA, RSA_R_MODULUS_TOO_LARGE);
294        return 0;
295    }
296
297#ifdef FIPS_MODULE
298    /*
299     * (Step a): modulus must be 2048 or 3072 (caveat from SP800-56Br1)
300     * NOTE: changed to allow keys >= 2048
301     */
302    if (!ossl_rsa_sp800_56b_validate_strength(nbits, -1)) {
303        ERR_raise(ERR_LIB_RSA, RSA_R_INVALID_KEY_LENGTH);
304        return 0;
305    }
306#endif
307    if (!BN_is_odd(rsa->n)) {
308        ERR_raise(ERR_LIB_RSA, RSA_R_INVALID_MODULUS);
309        return 0;
310    }
311    /* (Steps b-c): 2^16 < e < 2^256, n and e must be odd */
312    if (!ossl_rsa_check_public_exponent(rsa->e)) {
313        ERR_raise(ERR_LIB_RSA, RSA_R_PUB_EXPONENT_OUT_OF_RANGE);
314        return 0;
315    }
316
317    ctx = BN_CTX_new_ex(rsa->libctx);
318    gcd = BN_new();
319    if (ctx == NULL || gcd == NULL)
320        goto err;
321
322    /* (Steps d-f):
323     * The modulus is composite, but not a power of a prime.
324     * The modulus has no factors smaller than 752.
325     */
326    if (!BN_gcd(gcd, rsa->n, ossl_bn_get0_small_factors(), ctx)
327        || !BN_is_one(gcd)) {
328        ERR_raise(ERR_LIB_RSA, RSA_R_INVALID_MODULUS);
329        goto err;
330    }
331
332    /* Highest number of MR rounds from FIPS 186-5 Section B.3 Table B.1 */
333    ret = ossl_bn_miller_rabin_is_prime(rsa->n, 5, ctx, NULL, 1, &status);
334#ifdef FIPS_MODULE
335    if (ret != 1 || status != BN_PRIMETEST_COMPOSITE_NOT_POWER_OF_PRIME) {
336#else
337    if (ret != 1 || (status != BN_PRIMETEST_COMPOSITE_NOT_POWER_OF_PRIME
338                     && (nbits >= RSA_MIN_MODULUS_BITS
339                         || status != BN_PRIMETEST_COMPOSITE_WITH_FACTOR))) {
340#endif
341        ERR_raise(ERR_LIB_RSA, RSA_R_INVALID_MODULUS);
342        ret = 0;
343        goto err;
344    }
345
346    ret = 1;
347err:
348    BN_free(gcd);
349    BN_CTX_free(ctx);
350    return ret;
351}
352
353/*
354 * Perform validation of the RSA private key to check that 0 < D < N.
355 */
356int ossl_rsa_sp800_56b_check_private(const RSA *rsa)
357{
358    if (rsa->d == NULL || rsa->n == NULL)
359        return 0;
360    return BN_cmp(rsa->d, BN_value_one()) >= 0 && BN_cmp(rsa->d, rsa->n) < 0;
361}
362
363/*
364 * RSA key pair validation.
365 *
366 * SP800-56Br1.
367 *    6.4.1.2 "RSAKPV1 Family: RSA Key - Pair Validation with a Fixed Exponent"
368 *    6.4.1.3 "RSAKPV2 Family: RSA Key - Pair Validation with a Random Exponent"
369 *
370 * It uses:
371 *     6.4.1.2.3 "rsakpv1 - crt"
372 *     6.4.1.3.3 "rsakpv2 - crt"
373 */
374int ossl_rsa_sp800_56b_check_keypair(const RSA *rsa, const BIGNUM *efixed,
375                                     int strength, int nbits)
376{
377    int ret = 0;
378    BN_CTX *ctx = NULL;
379    BIGNUM *r = NULL;
380
381    if (rsa->p == NULL
382            || rsa->q == NULL
383            || rsa->e == NULL
384            || rsa->d == NULL
385            || rsa->n == NULL) {
386        ERR_raise(ERR_LIB_RSA, RSA_R_INVALID_REQUEST);
387        return 0;
388    }
389    /* (Step 1): Check Ranges */
390    if (!ossl_rsa_sp800_56b_validate_strength(nbits, strength))
391        return 0;
392
393    /* If the exponent is known */
394    if (efixed != NULL) {
395        /* (2): Check fixed exponent matches public exponent. */
396        if (BN_cmp(efixed, rsa->e) != 0) {
397            ERR_raise(ERR_LIB_RSA, RSA_R_INVALID_REQUEST);
398            return 0;
399        }
400    }
401    /* (Step 1.c): e is odd integer 65537 <= e < 2^256 */
402    if (!ossl_rsa_check_public_exponent(rsa->e)) {
403        /* exponent out of range */
404        ERR_raise(ERR_LIB_RSA, RSA_R_PUB_EXPONENT_OUT_OF_RANGE);
405        return 0;
406    }
407    /* (Step 3.b): check the modulus */
408    if (nbits != BN_num_bits(rsa->n)) {
409        ERR_raise(ERR_LIB_RSA, RSA_R_INVALID_KEYPAIR);
410        return 0;
411    }
412
413    ctx = BN_CTX_new_ex(rsa->libctx);
414    if (ctx == NULL)
415        return 0;
416
417    BN_CTX_start(ctx);
418    r = BN_CTX_get(ctx);
419    if (r == NULL || !BN_mul(r, rsa->p, rsa->q, ctx))
420        goto err;
421    /* (Step 4.c): Check n = pq */
422    if (BN_cmp(rsa->n, r) != 0) {
423        ERR_raise(ERR_LIB_RSA, RSA_R_INVALID_REQUEST);
424        goto err;
425    }
426
427    /* (Step 5): check prime factors p & q */
428    ret = ossl_rsa_check_prime_factor(rsa->p, rsa->e, nbits, ctx)
429          && ossl_rsa_check_prime_factor(rsa->q, rsa->e, nbits, ctx)
430          && (ossl_rsa_check_pminusq_diff(r, rsa->p, rsa->q, nbits) > 0)
431          /* (Step 6): Check the private exponent d */
432          && ossl_rsa_check_private_exponent(rsa, nbits, ctx)
433          /* 6.4.1.2.3 (Step 7): Check the CRT components */
434          && ossl_rsa_check_crt_components(rsa, ctx);
435    if (ret != 1)
436        ERR_raise(ERR_LIB_RSA, RSA_R_INVALID_KEYPAIR);
437
438err:
439    BN_clear(r);
440    BN_CTX_end(ctx);
441    BN_CTX_free(ctx);
442    return ret;
443}
444