1// SPDX-License-Identifier: GPL-2.0-or-later
2/*
3 * SM3 using the RISC-V vector crypto extensions
4 *
5 * Copyright (C) 2023 VRULL GmbH
6 * Author: Heiko Stuebner <heiko.stuebner@vrull.eu>
7 *
8 * Copyright (C) 2023 SiFive, Inc.
9 * Author: Jerry Shih <jerry.shih@sifive.com>
10 */
11
12#include <asm/simd.h>
13#include <asm/vector.h>
14#include <crypto/internal/hash.h>
15#include <crypto/internal/simd.h>
16#include <crypto/sm3_base.h>
17#include <linux/linkage.h>
18#include <linux/module.h>
19
20/*
21 * Note: the asm function only uses the 'state' field of struct sm3_state.
22 * It is assumed to be the first field.
23 */
24asmlinkage void sm3_transform_zvksh_zvkb(
25	struct sm3_state *state, const u8 *data, int num_blocks);
26
27static int riscv64_sm3_update(struct shash_desc *desc, const u8 *data,
28			      unsigned int len)
29{
30	/*
31	 * Ensure struct sm3_state begins directly with the SM3
32	 * 256-bit internal state, as this is what the asm function expects.
33	 */
34	BUILD_BUG_ON(offsetof(struct sm3_state, state) != 0);
35
36	if (crypto_simd_usable()) {
37		kernel_vector_begin();
38		sm3_base_do_update(desc, data, len, sm3_transform_zvksh_zvkb);
39		kernel_vector_end();
40	} else {
41		sm3_update(shash_desc_ctx(desc), data, len);
42	}
43	return 0;
44}
45
46static int riscv64_sm3_finup(struct shash_desc *desc, const u8 *data,
47			     unsigned int len, u8 *out)
48{
49	struct sm3_state *ctx;
50
51	if (crypto_simd_usable()) {
52		kernel_vector_begin();
53		if (len)
54			sm3_base_do_update(desc, data, len,
55					   sm3_transform_zvksh_zvkb);
56		sm3_base_do_finalize(desc, sm3_transform_zvksh_zvkb);
57		kernel_vector_end();
58
59		return sm3_base_finish(desc, out);
60	}
61
62	ctx = shash_desc_ctx(desc);
63	if (len)
64		sm3_update(ctx, data, len);
65	sm3_final(ctx, out);
66
67	return 0;
68}
69
70static int riscv64_sm3_final(struct shash_desc *desc, u8 *out)
71{
72	return riscv64_sm3_finup(desc, NULL, 0, out);
73}
74
75static struct shash_alg riscv64_sm3_alg = {
76	.init = sm3_base_init,
77	.update = riscv64_sm3_update,
78	.final = riscv64_sm3_final,
79	.finup = riscv64_sm3_finup,
80	.descsize = sizeof(struct sm3_state),
81	.digestsize = SM3_DIGEST_SIZE,
82	.base = {
83		.cra_blocksize = SM3_BLOCK_SIZE,
84		.cra_priority = 300,
85		.cra_name = "sm3",
86		.cra_driver_name = "sm3-riscv64-zvksh-zvkb",
87		.cra_module = THIS_MODULE,
88	},
89};
90
91static int __init riscv64_sm3_mod_init(void)
92{
93	if (riscv_isa_extension_available(NULL, ZVKSH) &&
94	    riscv_isa_extension_available(NULL, ZVKB) &&
95	    riscv_vector_vlen() >= 128)
96		return crypto_register_shash(&riscv64_sm3_alg);
97
98	return -ENODEV;
99}
100
101static void __exit riscv64_sm3_mod_exit(void)
102{
103	crypto_unregister_shash(&riscv64_sm3_alg);
104}
105
106module_init(riscv64_sm3_mod_init);
107module_exit(riscv64_sm3_mod_exit);
108
109MODULE_DESCRIPTION("SM3 (RISC-V accelerated)");
110MODULE_AUTHOR("Heiko Stuebner <heiko.stuebner@vrull.eu>");
111MODULE_LICENSE("GPL");
112MODULE_ALIAS_CRYPTO("sm3");
113