1/*
2 * Copyright 2020-2022 The OpenSSL Project Authors. All Rights Reserved.
3 *
4 * Licensed under the Apache License 2.0 (the "License").  You may not use
5 * this file except in compliance with the License.  You can obtain a copy
6 * in the file LICENSE in the source distribution or at
7 * https://www.openssl.org/source/license.html
8 */
9
10#include <stdio.h>
11#include <stdlib.h>
12#include <openssl/objects.h>
13#include <openssl/evp.h>
14#include "internal/cryptlib.h"
15#include "internal/provider.h"
16#include "internal/core.h"
17#include "crypto/evp.h"
18#include "evp_local.h"
19
20static int evp_kem_init(EVP_PKEY_CTX *ctx, int operation,
21                        const OSSL_PARAM params[])
22{
23    int ret = 0;
24    EVP_KEM *kem = NULL;
25    EVP_KEYMGMT *tmp_keymgmt = NULL;
26    const OSSL_PROVIDER *tmp_prov = NULL;
27    void *provkey = NULL;
28    const char *supported_kem = NULL;
29    int iter;
30
31    if (ctx == NULL || ctx->keytype == NULL) {
32        ERR_raise(ERR_LIB_EVP, EVP_R_INITIALIZATION_ERROR);
33        return 0;
34    }
35
36    evp_pkey_ctx_free_old_ops(ctx);
37    ctx->operation = operation;
38
39    if (ctx->pkey == NULL) {
40        ERR_raise(ERR_LIB_EVP, EVP_R_NO_KEY_SET);
41        goto err;
42    }
43
44    /*
45     * Try to derive the supported kem from |ctx->keymgmt|.
46     */
47    if (!ossl_assert(ctx->pkey->keymgmt == NULL
48                     || ctx->pkey->keymgmt == ctx->keymgmt)) {
49        ERR_raise(ERR_LIB_EVP, ERR_R_INTERNAL_ERROR);
50        goto err;
51    }
52    supported_kem = evp_keymgmt_util_query_operation_name(ctx->keymgmt,
53                                                          OSSL_OP_KEM);
54    if (supported_kem == NULL) {
55        ERR_raise(ERR_LIB_EVP, EVP_R_INITIALIZATION_ERROR);
56        goto err;
57    }
58
59    /*
60     * Because we cleared out old ops, we shouldn't need to worry about
61     * checking if kem is already there.
62     * We perform two iterations:
63     *
64     * 1.  Do the normal kem fetch, using the fetching data given by
65     *     the EVP_PKEY_CTX.
66     * 2.  Do the provider specific kem fetch, from the same provider
67     *     as |ctx->keymgmt|
68     *
69     * We then try to fetch the keymgmt from the same provider as the
70     * kem, and try to export |ctx->pkey| to that keymgmt (when this
71     * keymgmt happens to be the same as |ctx->keymgmt|, the export is
72     * a no-op, but we call it anyway to not complicate the code even
73     * more).
74     * If the export call succeeds (returns a non-NULL provider key pointer),
75     * we're done and can perform the operation itself.  If not, we perform
76     * the second iteration, or jump to legacy.
77     */
78    for (iter = 1, provkey = NULL; iter < 3 && provkey == NULL; iter++) {
79        EVP_KEYMGMT *tmp_keymgmt_tofree = NULL;
80
81        /*
82         * If we're on the second iteration, free the results from the first.
83         * They are NULL on the first iteration, so no need to check what
84         * iteration we're on.
85         */
86        EVP_KEM_free(kem);
87        EVP_KEYMGMT_free(tmp_keymgmt);
88
89        switch (iter) {
90        case 1:
91            kem = EVP_KEM_fetch(ctx->libctx, supported_kem, ctx->propquery);
92            if (kem != NULL)
93                tmp_prov = EVP_KEM_get0_provider(kem);
94            break;
95        case 2:
96            tmp_prov = EVP_KEYMGMT_get0_provider(ctx->keymgmt);
97            kem = evp_kem_fetch_from_prov((OSSL_PROVIDER *)tmp_prov,
98                                          supported_kem, ctx->propquery);
99
100            if (kem == NULL) {
101                ERR_raise(ERR_LIB_EVP,
102                          EVP_R_OPERATION_NOT_SUPPORTED_FOR_THIS_KEYTYPE);
103                ret = -2;
104                goto err;
105            }
106        }
107        if (kem == NULL)
108            continue;
109
110        /*
111         * Ensure that the key is provided, either natively, or as a cached
112         * export.  We start by fetching the keymgmt with the same name as
113         * |ctx->pkey|, but from the provider of the kem method, using the
114         * same property query as when fetching the kem method.
115         * With the keymgmt we found (if we did), we try to export |ctx->pkey|
116         * to it (evp_pkey_export_to_provider() is smart enough to only actually
117
118         * export it if |tmp_keymgmt| is different from |ctx->pkey|'s keymgmt)
119         */
120        tmp_keymgmt_tofree = tmp_keymgmt =
121            evp_keymgmt_fetch_from_prov((OSSL_PROVIDER *)tmp_prov,
122                                        EVP_KEYMGMT_get0_name(ctx->keymgmt),
123                                        ctx->propquery);
124        if (tmp_keymgmt != NULL)
125            provkey = evp_pkey_export_to_provider(ctx->pkey, ctx->libctx,
126                                                  &tmp_keymgmt, ctx->propquery);
127        if (tmp_keymgmt == NULL)
128            EVP_KEYMGMT_free(tmp_keymgmt_tofree);
129    }
130
131    if (provkey == NULL) {
132        EVP_KEM_free(kem);
133        ERR_raise(ERR_LIB_EVP, EVP_R_INITIALIZATION_ERROR);
134        goto err;
135    }
136
137    ctx->op.encap.kem = kem;
138    ctx->op.encap.algctx = kem->newctx(ossl_provider_ctx(kem->prov));
139    if (ctx->op.encap.algctx == NULL) {
140        /* The provider key can stay in the cache */
141        ERR_raise(ERR_LIB_EVP, EVP_R_INITIALIZATION_ERROR);
142        goto err;
143    }
144
145    switch (operation) {
146    case EVP_PKEY_OP_ENCAPSULATE:
147        if (kem->encapsulate_init == NULL) {
148            ERR_raise(ERR_LIB_EVP, EVP_R_OPERATION_NOT_SUPPORTED_FOR_THIS_KEYTYPE);
149            ret = -2;
150            goto err;
151        }
152        ret = kem->encapsulate_init(ctx->op.encap.algctx, provkey, params);
153        break;
154    case EVP_PKEY_OP_DECAPSULATE:
155        if (kem->decapsulate_init == NULL) {
156            ERR_raise(ERR_LIB_EVP, EVP_R_OPERATION_NOT_SUPPORTED_FOR_THIS_KEYTYPE);
157            ret = -2;
158            goto err;
159        }
160        ret = kem->decapsulate_init(ctx->op.encap.algctx, provkey, params);
161        break;
162    default:
163        ERR_raise(ERR_LIB_EVP, EVP_R_INITIALIZATION_ERROR);
164        goto err;
165    }
166
167    EVP_KEYMGMT_free(tmp_keymgmt);
168    tmp_keymgmt = NULL;
169
170    if (ret > 0)
171        return 1;
172 err:
173    if (ret <= 0) {
174        evp_pkey_ctx_free_old_ops(ctx);
175        ctx->operation = EVP_PKEY_OP_UNDEFINED;
176    }
177    EVP_KEYMGMT_free(tmp_keymgmt);
178    return ret;
179}
180
181int EVP_PKEY_encapsulate_init(EVP_PKEY_CTX *ctx, const OSSL_PARAM params[])
182{
183    return evp_kem_init(ctx, EVP_PKEY_OP_ENCAPSULATE, params);
184}
185
186int EVP_PKEY_encapsulate(EVP_PKEY_CTX *ctx,
187                         unsigned char *out, size_t *outlen,
188                         unsigned char *secret, size_t *secretlen)
189{
190    if (ctx == NULL)
191        return 0;
192
193    if (ctx->operation != EVP_PKEY_OP_ENCAPSULATE) {
194        ERR_raise(ERR_LIB_EVP, EVP_R_OPERATION_NOT_INITIALIZED);
195        return -1;
196    }
197
198    if (ctx->op.encap.algctx == NULL) {
199        ERR_raise(ERR_LIB_EVP, EVP_R_OPERATION_NOT_SUPPORTED_FOR_THIS_KEYTYPE);
200        return -2;
201    }
202
203    if (out != NULL && secret == NULL)
204        return 0;
205
206    return ctx->op.encap.kem->encapsulate(ctx->op.encap.algctx,
207                                          out, outlen, secret, secretlen);
208}
209
210int EVP_PKEY_decapsulate_init(EVP_PKEY_CTX *ctx, const OSSL_PARAM params[])
211{
212    return evp_kem_init(ctx, EVP_PKEY_OP_DECAPSULATE, params);
213}
214
215int EVP_PKEY_decapsulate(EVP_PKEY_CTX *ctx,
216                         unsigned char *secret, size_t *secretlen,
217                         const unsigned char *in, size_t inlen)
218{
219    if (ctx == NULL
220        || (in == NULL || inlen == 0)
221        || (secret == NULL && secretlen == NULL))
222        return 0;
223
224    if (ctx->operation != EVP_PKEY_OP_DECAPSULATE) {
225        ERR_raise(ERR_LIB_EVP, EVP_R_OPERATION_NOT_INITIALIZED);
226        return -1;
227    }
228
229    if (ctx->op.encap.algctx == NULL) {
230        ERR_raise(ERR_LIB_EVP, EVP_R_OPERATION_NOT_SUPPORTED_FOR_THIS_KEYTYPE);
231        return -2;
232    }
233    return ctx->op.encap.kem->decapsulate(ctx->op.encap.algctx,
234                                          secret, secretlen, in, inlen);
235}
236
237static EVP_KEM *evp_kem_new(OSSL_PROVIDER *prov)
238{
239    EVP_KEM *kem = OPENSSL_zalloc(sizeof(EVP_KEM));
240
241    if (kem == NULL) {
242        ERR_raise(ERR_LIB_EVP, ERR_R_MALLOC_FAILURE);
243        return NULL;
244    }
245
246    kem->lock = CRYPTO_THREAD_lock_new();
247    if (kem->lock == NULL) {
248        ERR_raise(ERR_LIB_EVP, ERR_R_MALLOC_FAILURE);
249        OPENSSL_free(kem);
250        return NULL;
251    }
252    kem->prov = prov;
253    ossl_provider_up_ref(prov);
254    kem->refcnt = 1;
255
256    return kem;
257}
258
259static void *evp_kem_from_algorithm(int name_id, const OSSL_ALGORITHM *algodef,
260                                    OSSL_PROVIDER *prov)
261{
262    const OSSL_DISPATCH *fns = algodef->implementation;
263    EVP_KEM *kem = NULL;
264    int ctxfncnt = 0, encfncnt = 0, decfncnt = 0;
265    int gparamfncnt = 0, sparamfncnt = 0;
266
267    if ((kem = evp_kem_new(prov)) == NULL) {
268        ERR_raise(ERR_LIB_EVP, ERR_R_MALLOC_FAILURE);
269        goto err;
270    }
271
272    kem->name_id = name_id;
273    if ((kem->type_name = ossl_algorithm_get1_first_name(algodef)) == NULL)
274        goto err;
275    kem->description = algodef->algorithm_description;
276
277    for (; fns->function_id != 0; fns++) {
278        switch (fns->function_id) {
279        case OSSL_FUNC_KEM_NEWCTX:
280            if (kem->newctx != NULL)
281                break;
282            kem->newctx = OSSL_FUNC_kem_newctx(fns);
283            ctxfncnt++;
284            break;
285        case OSSL_FUNC_KEM_ENCAPSULATE_INIT:
286            if (kem->encapsulate_init != NULL)
287                break;
288            kem->encapsulate_init = OSSL_FUNC_kem_encapsulate_init(fns);
289            encfncnt++;
290            break;
291        case OSSL_FUNC_KEM_ENCAPSULATE:
292            if (kem->encapsulate != NULL)
293                break;
294            kem->encapsulate = OSSL_FUNC_kem_encapsulate(fns);
295            encfncnt++;
296            break;
297        case OSSL_FUNC_KEM_DECAPSULATE_INIT:
298            if (kem->decapsulate_init != NULL)
299                break;
300            kem->decapsulate_init = OSSL_FUNC_kem_decapsulate_init(fns);
301            decfncnt++;
302            break;
303        case OSSL_FUNC_KEM_DECAPSULATE:
304            if (kem->decapsulate != NULL)
305                break;
306            kem->decapsulate = OSSL_FUNC_kem_decapsulate(fns);
307            decfncnt++;
308            break;
309        case OSSL_FUNC_KEM_FREECTX:
310            if (kem->freectx != NULL)
311                break;
312            kem->freectx = OSSL_FUNC_kem_freectx(fns);
313            ctxfncnt++;
314            break;
315        case OSSL_FUNC_KEM_DUPCTX:
316            if (kem->dupctx != NULL)
317                break;
318            kem->dupctx = OSSL_FUNC_kem_dupctx(fns);
319            break;
320        case OSSL_FUNC_KEM_GET_CTX_PARAMS:
321            if (kem->get_ctx_params != NULL)
322                break;
323            kem->get_ctx_params
324                = OSSL_FUNC_kem_get_ctx_params(fns);
325            gparamfncnt++;
326            break;
327        case OSSL_FUNC_KEM_GETTABLE_CTX_PARAMS:
328            if (kem->gettable_ctx_params != NULL)
329                break;
330            kem->gettable_ctx_params
331                = OSSL_FUNC_kem_gettable_ctx_params(fns);
332            gparamfncnt++;
333            break;
334        case OSSL_FUNC_KEM_SET_CTX_PARAMS:
335            if (kem->set_ctx_params != NULL)
336                break;
337            kem->set_ctx_params
338                = OSSL_FUNC_kem_set_ctx_params(fns);
339            sparamfncnt++;
340            break;
341        case OSSL_FUNC_KEM_SETTABLE_CTX_PARAMS:
342            if (kem->settable_ctx_params != NULL)
343                break;
344            kem->settable_ctx_params
345                = OSSL_FUNC_kem_settable_ctx_params(fns);
346            sparamfncnt++;
347            break;
348        }
349    }
350    if (ctxfncnt != 2
351        || (encfncnt != 0 && encfncnt != 2)
352        || (decfncnt != 0 && decfncnt != 2)
353        || (encfncnt != 2 && decfncnt != 2)
354        || (gparamfncnt != 0 && gparamfncnt != 2)
355        || (sparamfncnt != 0 && sparamfncnt != 2)) {
356        /*
357         * In order to be a consistent set of functions we must have at least
358         * a set of context functions (newctx and freectx) as well as a pair of
359         * "kem" functions: (encapsulate_init, encapsulate) or
360         * (decapsulate_init, decapsulate). set_ctx_params and settable_ctx_params are
361         * optional, but if one of them is present then the other one must also
362         * be present. The same applies to get_ctx_params and
363         * gettable_ctx_params. The dupctx function is optional.
364         */
365        ERR_raise(ERR_LIB_EVP, EVP_R_INVALID_PROVIDER_FUNCTIONS);
366        goto err;
367    }
368
369    return kem;
370 err:
371    EVP_KEM_free(kem);
372    return NULL;
373}
374
375void EVP_KEM_free(EVP_KEM *kem)
376{
377    int i;
378
379    if (kem == NULL)
380        return;
381
382    CRYPTO_DOWN_REF(&kem->refcnt, &i, kem->lock);
383    if (i > 0)
384        return;
385    OPENSSL_free(kem->type_name);
386    ossl_provider_free(kem->prov);
387    CRYPTO_THREAD_lock_free(kem->lock);
388    OPENSSL_free(kem);
389}
390
391int EVP_KEM_up_ref(EVP_KEM *kem)
392{
393    int ref = 0;
394
395    CRYPTO_UP_REF(&kem->refcnt, &ref, kem->lock);
396    return 1;
397}
398
399OSSL_PROVIDER *EVP_KEM_get0_provider(const EVP_KEM *kem)
400{
401    return kem->prov;
402}
403
404EVP_KEM *EVP_KEM_fetch(OSSL_LIB_CTX *ctx, const char *algorithm,
405                       const char *properties)
406{
407    return evp_generic_fetch(ctx, OSSL_OP_KEM, algorithm, properties,
408                             evp_kem_from_algorithm,
409                             (int (*)(void *))EVP_KEM_up_ref,
410                             (void (*)(void *))EVP_KEM_free);
411}
412
413EVP_KEM *evp_kem_fetch_from_prov(OSSL_PROVIDER *prov, const char *algorithm,
414                                 const char *properties)
415{
416    return evp_generic_fetch_from_prov(prov, OSSL_OP_KEM, algorithm, properties,
417                                       evp_kem_from_algorithm,
418                                       (int (*)(void *))EVP_KEM_up_ref,
419                                       (void (*)(void *))EVP_KEM_free);
420}
421
422int EVP_KEM_is_a(const EVP_KEM *kem, const char *name)
423{
424    return kem != NULL && evp_is_a(kem->prov, kem->name_id, NULL, name);
425}
426
427int evp_kem_get_number(const EVP_KEM *kem)
428{
429    return kem->name_id;
430}
431
432const char *EVP_KEM_get0_name(const EVP_KEM *kem)
433{
434    return kem->type_name;
435}
436
437const char *EVP_KEM_get0_description(const EVP_KEM *kem)
438{
439    return kem->description;
440}
441
442void EVP_KEM_do_all_provided(OSSL_LIB_CTX *libctx,
443                             void (*fn)(EVP_KEM *kem, void *arg),
444                             void *arg)
445{
446    evp_generic_do_all(libctx, OSSL_OP_KEM, (void (*)(void *, void *))fn, arg,
447                       evp_kem_from_algorithm,
448                       (int (*)(void *))EVP_KEM_up_ref,
449                       (void (*)(void *))EVP_KEM_free);
450}
451
452int EVP_KEM_names_do_all(const EVP_KEM *kem,
453                         void (*fn)(const char *name, void *data),
454                         void *data)
455{
456    if (kem->prov != NULL)
457        return evp_names_do_all(kem->prov, kem->name_id, fn, data);
458
459    return 1;
460}
461
462const OSSL_PARAM *EVP_KEM_gettable_ctx_params(const EVP_KEM *kem)
463{
464    void *provctx;
465
466    if (kem == NULL || kem->gettable_ctx_params == NULL)
467        return NULL;
468
469    provctx = ossl_provider_ctx(EVP_KEM_get0_provider(kem));
470    return kem->gettable_ctx_params(NULL, provctx);
471}
472
473const OSSL_PARAM *EVP_KEM_settable_ctx_params(const EVP_KEM *kem)
474{
475    void *provctx;
476
477    if (kem == NULL || kem->settable_ctx_params == NULL)
478        return NULL;
479
480    provctx = ossl_provider_ctx(EVP_KEM_get0_provider(kem));
481    return kem->settable_ctx_params(NULL, provctx);
482}
483