1// SPDX-License-Identifier: GPL-2.0-only
2/*
3 * Cryptographic API.
4 *
5 * Copyright (c) 2017-present, Facebook, Inc.
6 */
7#include <linux/crypto.h>
8#include <linux/init.h>
9#include <linux/interrupt.h>
10#include <linux/mm.h>
11#include <linux/module.h>
12#include <linux/net.h>
13#include <linux/vmalloc.h>
14#include <linux/zstd.h>
15#include <crypto/internal/scompress.h>
16
17
/* zstd compression level handed to zstd_get_params() by zstd_params(). */
#define ZSTD_DEF_LEVEL	3
19
/*
 * Per-transform zstd state.
 *
 * The compression and decompression contexts are created by
 * zstd_init_cctx()/zstd_init_dctx() from the matching workspace buffer,
 * which this driver allocates with vzalloc() and releases with vfree().
 */
struct zstd_ctx {
	zstd_cctx *cctx;	/* compression context, built in @cwksp */
	zstd_dctx *dctx;	/* decompression context, built in @dwksp */
	void *cwksp;		/* compression workspace buffer */
	void *dwksp;		/* decompression workspace buffer */
};
26
27static zstd_parameters zstd_params(void)
28{
29	return zstd_get_params(ZSTD_DEF_LEVEL, 0);
30}
31
32static int zstd_comp_init(struct zstd_ctx *ctx)
33{
34	int ret = 0;
35	const zstd_parameters params = zstd_params();
36	const size_t wksp_size = zstd_cctx_workspace_bound(&params.cParams);
37
38	ctx->cwksp = vzalloc(wksp_size);
39	if (!ctx->cwksp) {
40		ret = -ENOMEM;
41		goto out;
42	}
43
44	ctx->cctx = zstd_init_cctx(ctx->cwksp, wksp_size);
45	if (!ctx->cctx) {
46		ret = -EINVAL;
47		goto out_free;
48	}
49out:
50	return ret;
51out_free:
52	vfree(ctx->cwksp);
53	goto out;
54}
55
56static int zstd_decomp_init(struct zstd_ctx *ctx)
57{
58	int ret = 0;
59	const size_t wksp_size = zstd_dctx_workspace_bound();
60
61	ctx->dwksp = vzalloc(wksp_size);
62	if (!ctx->dwksp) {
63		ret = -ENOMEM;
64		goto out;
65	}
66
67	ctx->dctx = zstd_init_dctx(ctx->dwksp, wksp_size);
68	if (!ctx->dctx) {
69		ret = -EINVAL;
70		goto out_free;
71	}
72out:
73	return ret;
74out_free:
75	vfree(ctx->dwksp);
76	goto out;
77}
78
79static void zstd_comp_exit(struct zstd_ctx *ctx)
80{
81	vfree(ctx->cwksp);
82	ctx->cwksp = NULL;
83	ctx->cctx = NULL;
84}
85
86static void zstd_decomp_exit(struct zstd_ctx *ctx)
87{
88	vfree(ctx->dwksp);
89	ctx->dwksp = NULL;
90	ctx->dctx = NULL;
91}
92
/*
 * __zstd_init() - set up both compression and decompression state.
 * @ctx: a struct zstd_ctx (taken as void *).
 *
 * Return: 0 on success, otherwise the error from the step that failed.
 * If decompression setup fails, the already-initialised compression state
 * is torn down before returning.
 */
static int __zstd_init(void *ctx)
{
	int err = zstd_comp_init(ctx);

	if (!err) {
		err = zstd_decomp_init(ctx);
		if (err)
			zstd_comp_exit(ctx);
	}

	return err;
}
105
106static void *zstd_alloc_ctx(struct crypto_scomp *tfm)
107{
108	int ret;
109	struct zstd_ctx *ctx;
110
111	ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
112	if (!ctx)
113		return ERR_PTR(-ENOMEM);
114
115	ret = __zstd_init(ctx);
116	if (ret) {
117		kfree(ctx);
118		return ERR_PTR(ret);
119	}
120
121	return ctx;
122}
123
/*
 * zstd_init() - crypto_alg .cra_init hook; initialises the struct zstd_ctx
 * embedded in the transform (sized by .cra_ctxsize).
 *
 * Return: 0 on success or the error from __zstd_init().
 */
static int zstd_init(struct crypto_tfm *tfm)
{
	return __zstd_init(crypto_tfm_ctx(tfm));
}
130
/*
 * __zstd_exit() - release both workspaces of a struct zstd_ctx.
 * @ctx: a struct zstd_ctx (taken as void *).
 *
 * The two teardown steps touch disjoint fields, so their order is irrelevant.
 */
static void __zstd_exit(void *ctx)
{
	struct zstd_ctx *zctx = ctx;

	zstd_decomp_exit(zctx);
	zstd_comp_exit(zctx);
}
136
/*
 * zstd_free_ctx() - scomp .free_ctx hook, counterpart of zstd_alloc_ctx().
 * @tfm: unused.
 * @ctx: context previously returned by zstd_alloc_ctx().
 *
 * Frees both workspaces, then the context itself with kfree_sensitive().
 */
static void zstd_free_ctx(struct crypto_scomp *tfm, void *ctx)
{
	struct zstd_ctx *zctx = ctx;

	__zstd_exit(zctx);
	kfree_sensitive(zctx);
}
142
/*
 * zstd_exit() - crypto_alg .cra_exit hook; releases the workspaces of the
 * struct zstd_ctx embedded in the transform. The embedded context itself is
 * not freed here.
 */
static void zstd_exit(struct crypto_tfm *tfm)
{
	__zstd_exit(crypto_tfm_ctx(tfm));
}
149
150static int __zstd_compress(const u8 *src, unsigned int slen,
151			   u8 *dst, unsigned int *dlen, void *ctx)
152{
153	size_t out_len;
154	struct zstd_ctx *zctx = ctx;
155	const zstd_parameters params = zstd_params();
156
157	out_len = zstd_compress_cctx(zctx->cctx, dst, *dlen, src, slen, &params);
158	if (zstd_is_error(out_len))
159		return -EINVAL;
160	*dlen = out_len;
161	return 0;
162}
163
164static int zstd_compress(struct crypto_tfm *tfm, const u8 *src,
165			 unsigned int slen, u8 *dst, unsigned int *dlen)
166{
167	struct zstd_ctx *ctx = crypto_tfm_ctx(tfm);
168
169	return __zstd_compress(src, slen, dst, dlen, ctx);
170}
171
172static int zstd_scompress(struct crypto_scomp *tfm, const u8 *src,
173			  unsigned int slen, u8 *dst, unsigned int *dlen,
174			  void *ctx)
175{
176	return __zstd_compress(src, slen, dst, dlen, ctx);
177}
178
179static int __zstd_decompress(const u8 *src, unsigned int slen,
180			     u8 *dst, unsigned int *dlen, void *ctx)
181{
182	size_t out_len;
183	struct zstd_ctx *zctx = ctx;
184
185	out_len = zstd_decompress_dctx(zctx->dctx, dst, *dlen, src, slen);
186	if (zstd_is_error(out_len))
187		return -EINVAL;
188	*dlen = out_len;
189	return 0;
190}
191
192static int zstd_decompress(struct crypto_tfm *tfm, const u8 *src,
193			   unsigned int slen, u8 *dst, unsigned int *dlen)
194{
195	struct zstd_ctx *ctx = crypto_tfm_ctx(tfm);
196
197	return __zstd_decompress(src, slen, dst, dlen, ctx);
198}
199
200static int zstd_sdecompress(struct crypto_scomp *tfm, const u8 *src,
201			    unsigned int slen, u8 *dst, unsigned int *dlen,
202			    void *ctx)
203{
204	return __zstd_decompress(src, slen, dst, dlen, ctx);
205}
206
207static struct crypto_alg alg = {
208	.cra_name		= "zstd",
209	.cra_driver_name	= "zstd-generic",
210	.cra_flags		= CRYPTO_ALG_TYPE_COMPRESS,
211	.cra_ctxsize		= sizeof(struct zstd_ctx),
212	.cra_module		= THIS_MODULE,
213	.cra_init		= zstd_init,
214	.cra_exit		= zstd_exit,
215	.cra_u			= { .compress = {
216	.coa_compress		= zstd_compress,
217	.coa_decompress		= zstd_decompress } }
218};
219
220static struct scomp_alg scomp = {
221	.alloc_ctx		= zstd_alloc_ctx,
222	.free_ctx		= zstd_free_ctx,
223	.compress		= zstd_scompress,
224	.decompress		= zstd_sdecompress,
225	.base			= {
226		.cra_name	= "zstd",
227		.cra_driver_name = "zstd-scomp",
228		.cra_module	 = THIS_MODULE,
229	}
230};
231
232static int __init zstd_mod_init(void)
233{
234	int ret;
235
236	ret = crypto_register_alg(&alg);
237	if (ret)
238		return ret;
239
240	ret = crypto_register_scomp(&scomp);
241	if (ret)
242		crypto_unregister_alg(&alg);
243
244	return ret;
245}
246
247static void __exit zstd_mod_fini(void)
248{
249	crypto_unregister_alg(&alg);
250	crypto_unregister_scomp(&scomp);
251}
252
/*
 * Registered with subsys_initcall() rather than module_init(); presumably so
 * the algorithm is available early to built-in users - confirm.
 */
subsys_initcall(zstd_mod_init);
module_exit(zstd_mod_fini);

MODULE_LICENSE("GPL");
MODULE_DESCRIPTION("Zstd Compression Algorithm");
MODULE_ALIAS_CRYPTO("zstd");
259