1255736Sdavidch// SPDX-License-Identifier: GPL-2.0-only 2255736Sdavidch/* 3255736Sdavidch * AES using the RISC-V vector crypto extensions. Includes the bare block 4255736Sdavidch * cipher and the ECB, CBC, CBC-CTS, CTR, and XTS modes. 5255736Sdavidch * 6255736Sdavidch * Copyright (C) 2023 VRULL GmbH 7255736Sdavidch * Author: Heiko Stuebner <heiko.stuebner@vrull.eu> 8255736Sdavidch * 9255736Sdavidch * Copyright (C) 2023 SiFive, Inc. 10255736Sdavidch * Author: Jerry Shih <jerry.shih@sifive.com> 11255736Sdavidch * 12255736Sdavidch * Copyright 2024 Google LLC 13255736Sdavidch */ 14255736Sdavidch 15255736Sdavidch#include <asm/simd.h> 16255736Sdavidch#include <asm/vector.h> 17255736Sdavidch#include <crypto/aes.h> 18255736Sdavidch#include <crypto/internal/cipher.h> 19255736Sdavidch#include <crypto/internal/simd.h> 20255736Sdavidch#include <crypto/internal/skcipher.h> 21255736Sdavidch#include <crypto/scatterwalk.h> 22255736Sdavidch#include <crypto/xts.h> 23255736Sdavidch#include <linux/linkage.h> 24255736Sdavidch#include <linux/module.h> 25255736Sdavidch 26255736Sdavidchasmlinkage void aes_encrypt_zvkned(const struct crypto_aes_ctx *key, 27255736Sdavidch const u8 in[AES_BLOCK_SIZE], 28255736Sdavidch u8 out[AES_BLOCK_SIZE]); 29255736Sdavidchasmlinkage void aes_decrypt_zvkned(const struct crypto_aes_ctx *key, 30255736Sdavidch const u8 in[AES_BLOCK_SIZE], 31255736Sdavidch u8 out[AES_BLOCK_SIZE]); 32255736Sdavidch 33255736Sdavidchasmlinkage void aes_ecb_encrypt_zvkned(const struct crypto_aes_ctx *key, 34255736Sdavidch const u8 *in, u8 *out, size_t len); 35255736Sdavidchasmlinkage void aes_ecb_decrypt_zvkned(const struct crypto_aes_ctx *key, 36255736Sdavidch const u8 *in, u8 *out, size_t len); 37255736Sdavidch 38255736Sdavidchasmlinkage void aes_cbc_encrypt_zvkned(const struct crypto_aes_ctx *key, 39255736Sdavidch const u8 *in, u8 *out, size_t len, 40255736Sdavidch u8 iv[AES_BLOCK_SIZE]); 41255736Sdavidchasmlinkage void aes_cbc_decrypt_zvkned(const struct crypto_aes_ctx *key, 42255736Sdavidch const u8 *in, u8 *out, size_t len, 43255736Sdavidch u8 iv[AES_BLOCK_SIZE]); 44255736Sdavidch 45255736Sdavidchasmlinkage void aes_cbc_cts_crypt_zvkned(const struct crypto_aes_ctx *key, 46255736Sdavidch const u8 *in, u8 *out, size_t len, 47255736Sdavidch const u8 iv[AES_BLOCK_SIZE], bool enc); 48255736Sdavidch 49255736Sdavidchasmlinkage void aes_ctr32_crypt_zvkned_zvkb(const struct crypto_aes_ctx *key, 50255736Sdavidch const u8 *in, u8 *out, size_t len, 51255736Sdavidch u8 iv[AES_BLOCK_SIZE]); 52255736Sdavidch 53255736Sdavidchasmlinkage void aes_xts_encrypt_zvkned_zvbb_zvkg( 54255736Sdavidch const struct crypto_aes_ctx *key, 55255736Sdavidch const u8 *in, u8 *out, size_t len, 56255736Sdavidch u8 tweak[AES_BLOCK_SIZE]); 57255736Sdavidch 58255736Sdavidchasmlinkage void aes_xts_decrypt_zvkned_zvbb_zvkg( 59255736Sdavidch const struct crypto_aes_ctx *key, 60255736Sdavidch const u8 *in, u8 *out, size_t len, 61255736Sdavidch u8 tweak[AES_BLOCK_SIZE]); 62255736Sdavidch 63255736Sdavidchstatic int riscv64_aes_setkey(struct crypto_aes_ctx *ctx, 64255736Sdavidch const u8 *key, unsigned int keylen) 65255736Sdavidch{ 66255736Sdavidch /* 67255736Sdavidch * For now we just use the generic key expansion, for these reasons: 68255736Sdavidch * 69255736Sdavidch * - zvkned's key expansion instructions don't support AES-192. 70255736Sdavidch * So, non-zvkned fallback code would be needed anyway. 71255736Sdavidch * 72255736Sdavidch * - Users of AES in Linux usually don't change keys frequently. 73255736Sdavidch * So, key expansion isn't performance-critical. 74255736Sdavidch * 75255736Sdavidch * - For single-block AES exposed as a "cipher" algorithm, it's 76255736Sdavidch * necessary to use struct crypto_aes_ctx and initialize its 'key_dec' 77255736Sdavidch * field with the round keys for the Equivalent Inverse Cipher. This 78255736Sdavidch * is because with "cipher", decryption can be requested from a 79255736Sdavidch * context where the vector unit isn't usable, necessitating a 80255736Sdavidch * fallback to aes_decrypt(). But, zvkned can only generate and use 81255736Sdavidch * the normal round keys. Of course, it's preferable to not have 82255736Sdavidch * special code just for "cipher", as e.g. XTS also uses a 83255736Sdavidch * single-block AES encryption. It's simplest to just use 84255736Sdavidch * struct crypto_aes_ctx and aes_expandkey() everywhere. 85255736Sdavidch */ 86255736Sdavidch return aes_expandkey(ctx, key, keylen); 87255736Sdavidch} 88255736Sdavidch 89255736Sdavidchstatic int riscv64_aes_setkey_cipher(struct crypto_tfm *tfm, 90255736Sdavidch const u8 *key, unsigned int keylen) 91255736Sdavidch{ 92255736Sdavidch struct crypto_aes_ctx *ctx = crypto_tfm_ctx(tfm); 93255736Sdavidch 94255736Sdavidch return riscv64_aes_setkey(ctx, key, keylen); 95255736Sdavidch} 96255736Sdavidch 97255736Sdavidchstatic int riscv64_aes_setkey_skcipher(struct crypto_skcipher *tfm, 98255736Sdavidch const u8 *key, unsigned int keylen) 99255736Sdavidch{ 100255736Sdavidch struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm); 101255736Sdavidch 102255736Sdavidch return riscv64_aes_setkey(ctx, key, keylen); 103255736Sdavidch} 104255736Sdavidch 105255736Sdavidch/* Bare AES, without a mode of operation */ 106255736Sdavidch 107255736Sdavidchstatic void riscv64_aes_encrypt(struct crypto_tfm *tfm, u8 *dst, const u8 *src) 108255736Sdavidch{ 109255736Sdavidch const struct crypto_aes_ctx *ctx = crypto_tfm_ctx(tfm); 110255736Sdavidch 111255736Sdavidch if (crypto_simd_usable()) { 112255736Sdavidch kernel_vector_begin(); 113255736Sdavidch aes_encrypt_zvkned(ctx, src, dst); 114255736Sdavidch kernel_vector_end(); 115255736Sdavidch } else { 116255736Sdavidch aes_encrypt(ctx, dst, src); 117255736Sdavidch } 118255736Sdavidch} 119255736Sdavidch 120255736Sdavidchstatic void riscv64_aes_decrypt(struct crypto_tfm *tfm, u8 *dst, const u8 *src) 121255736Sdavidch{ 122255736Sdavidch const struct crypto_aes_ctx *ctx = crypto_tfm_ctx(tfm); 123255736Sdavidch 124255736Sdavidch if (crypto_simd_usable()) { 125255736Sdavidch kernel_vector_begin(); 126255736Sdavidch aes_decrypt_zvkned(ctx, src, dst); 127255736Sdavidch kernel_vector_end(); 128255736Sdavidch } else { 129255736Sdavidch aes_decrypt(ctx, dst, src); 130255736Sdavidch } 131255736Sdavidch} 132255736Sdavidch 133255736Sdavidch/* AES-ECB */ 134255736Sdavidch 135255736Sdavidchstatic inline int riscv64_aes_ecb_crypt(struct skcipher_request *req, bool enc) 136255736Sdavidch{ 137255736Sdavidch struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req); 138255736Sdavidch const struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm); 139255736Sdavidch struct skcipher_walk walk; 140255736Sdavidch unsigned int nbytes; 141255736Sdavidch int err; 142255736Sdavidch 143255736Sdavidch err = skcipher_walk_virt(&walk, req, false); 144255736Sdavidch while ((nbytes = walk.nbytes) != 0) { 145255736Sdavidch kernel_vector_begin(); 146255736Sdavidch if (enc) 147255736Sdavidch aes_ecb_encrypt_zvkned(ctx, walk.src.virt.addr, 148255736Sdavidch walk.dst.virt.addr, 149255736Sdavidch nbytes & ~(AES_BLOCK_SIZE - 1)); 150255736Sdavidch else 151255736Sdavidch aes_ecb_decrypt_zvkned(ctx, walk.src.virt.addr, 152255736Sdavidch walk.dst.virt.addr, 153255736Sdavidch nbytes & ~(AES_BLOCK_SIZE - 1)); 154255736Sdavidch kernel_vector_end(); 155255736Sdavidch err = skcipher_walk_done(&walk, nbytes & (AES_BLOCK_SIZE - 1)); 156255736Sdavidch } 157255736Sdavidch 158255736Sdavidch return err; 159255736Sdavidch} 160255736Sdavidch 161255736Sdavidchstatic int riscv64_aes_ecb_encrypt(struct skcipher_request *req) 162255736Sdavidch{ 163255736Sdavidch return riscv64_aes_ecb_crypt(req, true); 164255736Sdavidch} 165255736Sdavidch 166255736Sdavidchstatic int riscv64_aes_ecb_decrypt(struct skcipher_request *req) 167255736Sdavidch{ 168255736Sdavidch return riscv64_aes_ecb_crypt(req, false); 169255736Sdavidch} 170255736Sdavidch 171255736Sdavidch/* AES-CBC */ 172255736Sdavidch 173255736Sdavidchstatic int riscv64_aes_cbc_crypt(struct skcipher_request *req, bool enc) 174255736Sdavidch{ 175255736Sdavidch struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req); 176255736Sdavidch const struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm); 177255736Sdavidch struct skcipher_walk walk; 178255736Sdavidch unsigned int nbytes; 179255736Sdavidch int err; 180255736Sdavidch 181255736Sdavidch err = skcipher_walk_virt(&walk, req, false); 182255736Sdavidch while ((nbytes = walk.nbytes) != 0) { 183255736Sdavidch kernel_vector_begin(); 184255736Sdavidch if (enc) 185255736Sdavidch aes_cbc_encrypt_zvkned(ctx, walk.src.virt.addr, 186255736Sdavidch walk.dst.virt.addr, 187255736Sdavidch nbytes & ~(AES_BLOCK_SIZE - 1), 188255736Sdavidch walk.iv); 189255736Sdavidch else 190255736Sdavidch aes_cbc_decrypt_zvkned(ctx, walk.src.virt.addr, 191255736Sdavidch walk.dst.virt.addr, 192255736Sdavidch nbytes & ~(AES_BLOCK_SIZE - 1), 193255736Sdavidch walk.iv); 194255736Sdavidch kernel_vector_end(); 195255736Sdavidch err = skcipher_walk_done(&walk, nbytes & (AES_BLOCK_SIZE - 1)); 196255736Sdavidch } 197255736Sdavidch 198255736Sdavidch return err; 199255736Sdavidch} 200255736Sdavidch 201255736Sdavidchstatic int riscv64_aes_cbc_encrypt(struct skcipher_request *req) 202255736Sdavidch{ 203255736Sdavidch return riscv64_aes_cbc_crypt(req, true); 204255736Sdavidch} 205255736Sdavidch 206255736Sdavidchstatic int riscv64_aes_cbc_decrypt(struct skcipher_request *req) 207255736Sdavidch{ 208255736Sdavidch return riscv64_aes_cbc_crypt(req, false); 209255736Sdavidch} 210255736Sdavidch 211255736Sdavidch/* AES-CBC-CTS */ 212255736Sdavidch 213255736Sdavidchstatic int riscv64_aes_cbc_cts_crypt(struct skcipher_request *req, bool enc) 214255736Sdavidch{ 215255736Sdavidch struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req); 216255736Sdavidch const struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm); 217255736Sdavidch struct scatterlist sg_src[2], sg_dst[2]; 218255736Sdavidch struct skcipher_request subreq; 219255736Sdavidch struct scatterlist *src, *dst; 220255736Sdavidch struct skcipher_walk walk; 221255736Sdavidch unsigned int cbc_len; 222255736Sdavidch int err; 223255736Sdavidch 224255736Sdavidch if (req->cryptlen < AES_BLOCK_SIZE) 225255736Sdavidch return -EINVAL; 226255736Sdavidch 227255736Sdavidch err = skcipher_walk_virt(&walk, req, false); 228255736Sdavidch if (err) 229255736Sdavidch return err; 230255736Sdavidch /* 231255736Sdavidch * If the full message is available in one step, decrypt it in one call 232255736Sdavidch * to the CBC-CTS assembly function. This reduces overhead, especially 233255736Sdavidch * on short messages. Otherwise, fall back to doing CBC up to the last 234255736Sdavidch * two blocks, then invoke CTS just for the ciphertext stealing. 235255736Sdavidch */ 236255736Sdavidch if (unlikely(walk.nbytes != req->cryptlen)) { 237255736Sdavidch cbc_len = round_down(req->cryptlen - AES_BLOCK_SIZE - 1, 238255736Sdavidch AES_BLOCK_SIZE); 239255736Sdavidch skcipher_walk_abort(&walk); 240255736Sdavidch skcipher_request_set_tfm(&subreq, tfm); 241255736Sdavidch skcipher_request_set_callback(&subreq, 242255736Sdavidch skcipher_request_flags(req), 243255736Sdavidch NULL, NULL); 244255736Sdavidch skcipher_request_set_crypt(&subreq, req->src, req->dst, 245255736Sdavidch cbc_len, req->iv); 246255736Sdavidch err = riscv64_aes_cbc_crypt(&subreq, enc); 247255736Sdavidch if (err) 248255736Sdavidch return err; 249255736Sdavidch dst = src = scatterwalk_ffwd(sg_src, req->src, cbc_len); 250255736Sdavidch if (req->dst != req->src) 251255736Sdavidch dst = scatterwalk_ffwd(sg_dst, req->dst, cbc_len); 252255736Sdavidch skcipher_request_set_crypt(&subreq, src, dst, 253255736Sdavidch req->cryptlen - cbc_len, req->iv); 254255736Sdavidch err = skcipher_walk_virt(&walk, &subreq, false); 255255736Sdavidch if (err) 256255736Sdavidch return err; 257255736Sdavidch } 258255736Sdavidch kernel_vector_begin(); 259255736Sdavidch aes_cbc_cts_crypt_zvkned(ctx, walk.src.virt.addr, walk.dst.virt.addr, 260255736Sdavidch walk.nbytes, req->iv, enc); 261255736Sdavidch kernel_vector_end(); 262255736Sdavidch return skcipher_walk_done(&walk, 0); 263255736Sdavidch} 264255736Sdavidch 265255736Sdavidchstatic int riscv64_aes_cbc_cts_encrypt(struct skcipher_request *req) 266255736Sdavidch{ 267255736Sdavidch return riscv64_aes_cbc_cts_crypt(req, true); 268255736Sdavidch} 269255736Sdavidch 270255736Sdavidchstatic int riscv64_aes_cbc_cts_decrypt(struct skcipher_request *req) 271255736Sdavidch{ 272255736Sdavidch return riscv64_aes_cbc_cts_crypt(req, false); 273255736Sdavidch} 274255736Sdavidch 275255736Sdavidch/* AES-CTR */ 276255736Sdavidch 277255736Sdavidchstatic int riscv64_aes_ctr_crypt(struct skcipher_request *req) 278255736Sdavidch{ 279255736Sdavidch struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req); 280255736Sdavidch const struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm); 281255736Sdavidch unsigned int nbytes, p1_nbytes; 282255736Sdavidch struct skcipher_walk walk; 283255736Sdavidch u32 ctr32, nblocks; 284255736Sdavidch int err; 285255736Sdavidch 286255736Sdavidch /* Get the low 32-bit word of the 128-bit big endian counter. */ 287255736Sdavidch ctr32 = get_unaligned_be32(req->iv + 12); 288255736Sdavidch 289255736Sdavidch err = skcipher_walk_virt(&walk, req, false); 290255736Sdavidch while ((nbytes = walk.nbytes) != 0) { 291255736Sdavidch if (nbytes < walk.total) { 292255736Sdavidch /* Not the end yet, so keep the length block-aligned. */ 293255736Sdavidch nbytes = round_down(nbytes, AES_BLOCK_SIZE); 294255736Sdavidch nblocks = nbytes / AES_BLOCK_SIZE; 295255736Sdavidch } else { 296255736Sdavidch /* It's the end, so include any final partial block. */ 297255736Sdavidch nblocks = DIV_ROUND_UP(nbytes, AES_BLOCK_SIZE); 298255736Sdavidch } 299255736Sdavidch ctr32 += nblocks; 300255736Sdavidch 301255736Sdavidch kernel_vector_begin(); 302255736Sdavidch if (ctr32 >= nblocks) { 303255736Sdavidch /* The low 32-bit word of the counter won't overflow. */ 304255736Sdavidch aes_ctr32_crypt_zvkned_zvkb(ctx, walk.src.virt.addr, 305255736Sdavidch walk.dst.virt.addr, nbytes, 306255736Sdavidch req->iv); 307255736Sdavidch } else { 308255736Sdavidch /* 309255736Sdavidch * The low 32-bit word of the counter will overflow. 310255736Sdavidch * The assembly doesn't handle this case, so split the 311255736Sdavidch * operation into two at the point where the overflow 312255736Sdavidch * will occur. After the first part, add the carry bit. 313255736Sdavidch */ 314255736Sdavidch p1_nbytes = min_t(unsigned int, nbytes, 315255736Sdavidch (nblocks - ctr32) * AES_BLOCK_SIZE); 316255736Sdavidch aes_ctr32_crypt_zvkned_zvkb(ctx, walk.src.virt.addr, 317255736Sdavidch walk.dst.virt.addr, 318255736Sdavidch p1_nbytes, req->iv); 319255736Sdavidch crypto_inc(req->iv, 12); 320255736Sdavidch 321255736Sdavidch if (ctr32) { 322255736Sdavidch aes_ctr32_crypt_zvkned_zvkb( 323255736Sdavidch ctx, 324255736Sdavidch walk.src.virt.addr + p1_nbytes, 325255736Sdavidch walk.dst.virt.addr + p1_nbytes, 326255736Sdavidch nbytes - p1_nbytes, req->iv); 327255736Sdavidch } 328255736Sdavidch } 329255736Sdavidch kernel_vector_end(); 330255736Sdavidch 331255736Sdavidch err = skcipher_walk_done(&walk, walk.nbytes - nbytes); 332255736Sdavidch } 333255736Sdavidch 334255736Sdavidch return err; 335255736Sdavidch} 336255736Sdavidch 337255736Sdavidch/* AES-XTS */ 338255736Sdavidch 339255736Sdavidchstruct riscv64_aes_xts_ctx { 340255736Sdavidch struct crypto_aes_ctx ctx1; 341255736Sdavidch struct crypto_aes_ctx ctx2; 342255736Sdavidch}; 343255736Sdavidch 344255736Sdavidchstatic int riscv64_aes_xts_setkey(struct crypto_skcipher *tfm, const u8 *key, 345255736Sdavidch unsigned int keylen) 346255736Sdavidch{ 347255736Sdavidch struct riscv64_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm); 348255736Sdavidch 349255736Sdavidch return xts_verify_key(tfm, key, keylen) ?: 350255736Sdavidch riscv64_aes_setkey(&ctx->ctx1, key, keylen / 2) ?: 351255736Sdavidch riscv64_aes_setkey(&ctx->ctx2, key + keylen / 2, keylen / 2); 352255736Sdavidch} 353255736Sdavidch 354255736Sdavidchstatic int riscv64_aes_xts_crypt(struct skcipher_request *req, bool enc) 355255736Sdavidch{ 356255736Sdavidch struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req); 357255736Sdavidch const struct riscv64_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm); 358255736Sdavidch int tail = req->cryptlen % AES_BLOCK_SIZE; 359255736Sdavidch struct scatterlist sg_src[2], sg_dst[2]; 360255736Sdavidch struct skcipher_request subreq; 361255736Sdavidch struct scatterlist *src, *dst; 362255736Sdavidch struct skcipher_walk walk; 363255736Sdavidch int err; 364255736Sdavidch 365255736Sdavidch if (req->cryptlen < AES_BLOCK_SIZE) 366255736Sdavidch return -EINVAL; 367255736Sdavidch 368255736Sdavidch /* Encrypt the IV with the tweak key to get the first tweak. */ 369255736Sdavidch kernel_vector_begin(); 370255736Sdavidch aes_encrypt_zvkned(&ctx->ctx2, req->iv, req->iv); 371255736Sdavidch kernel_vector_end(); 372255736Sdavidch 373255736Sdavidch err = skcipher_walk_virt(&walk, req, false); 374255736Sdavidch 375255736Sdavidch /* 376255736Sdavidch * If the message length isn't divisible by the AES block size and the 377255736Sdavidch * full message isn't available in one step of the scatterlist walk, 378255736Sdavidch * then separate off the last full block and the partial block. This 379255736Sdavidch * ensures that they are processed in the same call to the assembly 380255736Sdavidch * function, which is required for ciphertext stealing. 381255736Sdavidch */ 382255736Sdavidch if (unlikely(tail > 0 && walk.nbytes < walk.total)) { 383255736Sdavidch skcipher_walk_abort(&walk); 384255736Sdavidch 385255736Sdavidch skcipher_request_set_tfm(&subreq, tfm); 386255736Sdavidch skcipher_request_set_callback(&subreq, 387255736Sdavidch skcipher_request_flags(req), 388255736Sdavidch NULL, NULL); 389255736Sdavidch skcipher_request_set_crypt(&subreq, req->src, req->dst, 390255736Sdavidch req->cryptlen - tail - AES_BLOCK_SIZE, 391255736Sdavidch req->iv); 392255736Sdavidch req = &subreq; 393255736Sdavidch err = skcipher_walk_virt(&walk, req, false); 394255736Sdavidch } else { 395255736Sdavidch tail = 0; 396255736Sdavidch } 397255736Sdavidch 398255736Sdavidch while (walk.nbytes) { 399255736Sdavidch unsigned int nbytes = walk.nbytes; 400255736Sdavidch 401255736Sdavidch if (nbytes < walk.total) 402255736Sdavidch nbytes = round_down(nbytes, AES_BLOCK_SIZE); 403255736Sdavidch 404255736Sdavidch kernel_vector_begin(); 405255736Sdavidch if (enc) 406255736Sdavidch aes_xts_encrypt_zvkned_zvbb_zvkg( 407255736Sdavidch &ctx->ctx1, walk.src.virt.addr, 408255736Sdavidch walk.dst.virt.addr, nbytes, req->iv); 409255736Sdavidch else 410255736Sdavidch aes_xts_decrypt_zvkned_zvbb_zvkg( 411255736Sdavidch &ctx->ctx1, walk.src.virt.addr, 412255736Sdavidch walk.dst.virt.addr, nbytes, req->iv); 413255736Sdavidch kernel_vector_end(); 414255736Sdavidch err = skcipher_walk_done(&walk, walk.nbytes - nbytes); 415255736Sdavidch } 416255736Sdavidch 417255736Sdavidch if (err || likely(!tail)) 418255736Sdavidch return err; 419255736Sdavidch 420255736Sdavidch /* Do ciphertext stealing with the last full block and partial block. */ 421255736Sdavidch 422255736Sdavidch dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen); 423255736Sdavidch if (req->dst != req->src) 424255736Sdavidch dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen); 425255736Sdavidch 426255736Sdavidch skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail, 427255736Sdavidch req->iv); 428255736Sdavidch 429255736Sdavidch err = skcipher_walk_virt(&walk, req, false); 430255736Sdavidch if (err) 431255736Sdavidch return err; 432255736Sdavidch 433255736Sdavidch kernel_vector_begin(); 434255736Sdavidch if (enc) 435255736Sdavidch aes_xts_encrypt_zvkned_zvbb_zvkg( 436255736Sdavidch &ctx->ctx1, walk.src.virt.addr, 437255736Sdavidch walk.dst.virt.addr, walk.nbytes, req->iv); 438255736Sdavidch else 439255736Sdavidch aes_xts_decrypt_zvkned_zvbb_zvkg( 440255736Sdavidch &ctx->ctx1, walk.src.virt.addr, 441255736Sdavidch walk.dst.virt.addr, walk.nbytes, req->iv); 442255736Sdavidch kernel_vector_end(); 443255736Sdavidch 444255736Sdavidch return skcipher_walk_done(&walk, 0); 445255736Sdavidch} 446255736Sdavidch 447255736Sdavidchstatic int riscv64_aes_xts_encrypt(struct skcipher_request *req) 448255736Sdavidch{ 449255736Sdavidch return riscv64_aes_xts_crypt(req, true); 450255736Sdavidch} 451255736Sdavidch 452255736Sdavidchstatic int riscv64_aes_xts_decrypt(struct skcipher_request *req) 453255736Sdavidch{ 454255736Sdavidch return riscv64_aes_xts_crypt(req, false); 455255736Sdavidch} 456255736Sdavidch 457255736Sdavidch/* Algorithm definitions */ 458255736Sdavidch 459255736Sdavidchstatic struct crypto_alg riscv64_zvkned_aes_cipher_alg = { 460255736Sdavidch .cra_flags = CRYPTO_ALG_TYPE_CIPHER, 461255736Sdavidch .cra_blocksize = AES_BLOCK_SIZE, 462255736Sdavidch .cra_ctxsize = sizeof(struct crypto_aes_ctx), 463255736Sdavidch .cra_priority = 300, 464255736Sdavidch .cra_name = "aes", 465255736Sdavidch .cra_driver_name = "aes-riscv64-zvkned", 466255736Sdavidch .cra_cipher = { 467255736Sdavidch .cia_min_keysize = AES_MIN_KEY_SIZE, 468255736Sdavidch .cia_max_keysize = AES_MAX_KEY_SIZE, 469255736Sdavidch .cia_setkey = riscv64_aes_setkey_cipher, 470255736Sdavidch .cia_encrypt = riscv64_aes_encrypt, 471255736Sdavidch .cia_decrypt = riscv64_aes_decrypt, 472255736Sdavidch }, 473255736Sdavidch .cra_module = THIS_MODULE, 474255736Sdavidch}; 475255736Sdavidch 476255736Sdavidchstatic struct skcipher_alg riscv64_zvkned_aes_skcipher_algs[] = { 477255736Sdavidch { 478255736Sdavidch .setkey = riscv64_aes_setkey_skcipher, 479255736Sdavidch .encrypt = riscv64_aes_ecb_encrypt, 480255736Sdavidch .decrypt = riscv64_aes_ecb_decrypt, 481255736Sdavidch .min_keysize = AES_MIN_KEY_SIZE, 482255736Sdavidch .max_keysize = AES_MAX_KEY_SIZE, 483258203Sedavis .walksize = 8 * AES_BLOCK_SIZE, /* matches LMUL=8 */ 484255736Sdavidch .base = { 485255736Sdavidch .cra_blocksize = AES_BLOCK_SIZE, 486255736Sdavidch .cra_ctxsize = sizeof(struct crypto_aes_ctx), 487255736Sdavidch .cra_priority = 300, 488255736Sdavidch .cra_name = "ecb(aes)", 489255736Sdavidch .cra_driver_name = "ecb-aes-riscv64-zvkned", 490255736Sdavidch .cra_module = THIS_MODULE, 491255736Sdavidch }, 492255736Sdavidch }, { 493255736Sdavidch .setkey = riscv64_aes_setkey_skcipher, 494255736Sdavidch .encrypt = riscv64_aes_cbc_encrypt, 495255736Sdavidch .decrypt = riscv64_aes_cbc_decrypt, 496255736Sdavidch .min_keysize = AES_MIN_KEY_SIZE, 497255736Sdavidch .max_keysize = AES_MAX_KEY_SIZE, 498255736Sdavidch .ivsize = AES_BLOCK_SIZE, 499255736Sdavidch .base = { 500255736Sdavidch .cra_blocksize = AES_BLOCK_SIZE, 501255736Sdavidch .cra_ctxsize = sizeof(struct crypto_aes_ctx), 502255736Sdavidch .cra_priority = 300, 503255736Sdavidch .cra_name = "cbc(aes)", 504255736Sdavidch .cra_driver_name = "cbc-aes-riscv64-zvkned", 505255736Sdavidch .cra_module = THIS_MODULE, 506255736Sdavidch }, 507255736Sdavidch }, { 508255736Sdavidch .setkey = riscv64_aes_setkey_skcipher, 509255736Sdavidch .encrypt = riscv64_aes_cbc_cts_encrypt, 510255736Sdavidch .decrypt = riscv64_aes_cbc_cts_decrypt, 511255736Sdavidch .min_keysize = AES_MIN_KEY_SIZE, 512255736Sdavidch .max_keysize = AES_MAX_KEY_SIZE, 513255736Sdavidch .ivsize = AES_BLOCK_SIZE, 514255736Sdavidch .walksize = 4 * AES_BLOCK_SIZE, /* matches LMUL=4 */ 515255736Sdavidch .base = { 516255736Sdavidch .cra_blocksize = AES_BLOCK_SIZE, 517255736Sdavidch .cra_ctxsize = sizeof(struct crypto_aes_ctx), 518255736Sdavidch .cra_priority = 300, 519255736Sdavidch .cra_name = "cts(cbc(aes))", 520255736Sdavidch .cra_driver_name = "cts-cbc-aes-riscv64-zvkned", 521255736Sdavidch .cra_module = THIS_MODULE, 522255736Sdavidch }, 523255736Sdavidch } 524255736Sdavidch}; 525255736Sdavidch 526255736Sdavidchstatic struct skcipher_alg riscv64_zvkned_zvkb_aes_skcipher_alg = { 527255736Sdavidch .setkey = riscv64_aes_setkey_skcipher, 528255736Sdavidch .encrypt = riscv64_aes_ctr_crypt, 529255736Sdavidch .decrypt = riscv64_aes_ctr_crypt, 530255736Sdavidch .min_keysize = AES_MIN_KEY_SIZE, 531255736Sdavidch .max_keysize = AES_MAX_KEY_SIZE, 532255736Sdavidch .ivsize = AES_BLOCK_SIZE, 533255736Sdavidch .chunksize = AES_BLOCK_SIZE, 534255736Sdavidch .walksize = 4 * AES_BLOCK_SIZE, /* matches LMUL=4 */ 535255736Sdavidch .base = { 536255736Sdavidch .cra_blocksize = 1, 537255736Sdavidch .cra_ctxsize = sizeof(struct crypto_aes_ctx), 538255736Sdavidch .cra_priority = 300, 539255736Sdavidch .cra_name = "ctr(aes)", 540255736Sdavidch .cra_driver_name = "ctr-aes-riscv64-zvkned-zvkb", 541255736Sdavidch .cra_module = THIS_MODULE, 542255736Sdavidch }, 543255736Sdavidch}; 544255736Sdavidch 545255736Sdavidchstatic struct skcipher_alg riscv64_zvkned_zvbb_zvkg_aes_skcipher_alg = { 546255736Sdavidch .setkey = riscv64_aes_xts_setkey, 547255736Sdavidch .encrypt = riscv64_aes_xts_encrypt, 548255736Sdavidch .decrypt = riscv64_aes_xts_decrypt, 549255736Sdavidch .min_keysize = 2 * AES_MIN_KEY_SIZE, 550255736Sdavidch .max_keysize = 2 * AES_MAX_KEY_SIZE, 551255736Sdavidch .ivsize = AES_BLOCK_SIZE, 552255736Sdavidch .chunksize = AES_BLOCK_SIZE, 553255736Sdavidch .walksize = 4 * AES_BLOCK_SIZE, /* matches LMUL=4 */ 554255736Sdavidch .base = { 555255736Sdavidch .cra_blocksize = AES_BLOCK_SIZE, 556255736Sdavidch .cra_ctxsize = sizeof(struct riscv64_aes_xts_ctx), 557255736Sdavidch .cra_priority = 300, 558255736Sdavidch .cra_name = "xts(aes)", 559255736Sdavidch .cra_driver_name = "xts-aes-riscv64-zvkned-zvbb-zvkg", 560255736Sdavidch .cra_module = THIS_MODULE, 561255736Sdavidch }, 562255736Sdavidch}; 563255736Sdavidch 564255736Sdavidchstatic inline bool riscv64_aes_xts_supported(void) 565255736Sdavidch{ 566255736Sdavidch return riscv_isa_extension_available(NULL, ZVBB) && 567255736Sdavidch riscv_isa_extension_available(NULL, ZVKG) && 568255736Sdavidch riscv_vector_vlen() < 2048 /* Implementation limitation */; 569255736Sdavidch} 570255736Sdavidch 571255736Sdavidchstatic int __init riscv64_aes_mod_init(void) 572255736Sdavidch{ 573255736Sdavidch int err = -ENODEV; 574255736Sdavidch 575255736Sdavidch if (riscv_isa_extension_available(NULL, ZVKNED) && 576255736Sdavidch riscv_vector_vlen() >= 128) { 577255736Sdavidch err = crypto_register_alg(&riscv64_zvkned_aes_cipher_alg); 578255736Sdavidch if (err) 579255736Sdavidch return err; 580255736Sdavidch 581255736Sdavidch err = crypto_register_skciphers( 582255736Sdavidch riscv64_zvkned_aes_skcipher_algs, 583255736Sdavidch ARRAY_SIZE(riscv64_zvkned_aes_skcipher_algs)); 584255736Sdavidch if (err) 585255736Sdavidch goto unregister_zvkned_cipher_alg; 586255736Sdavidch 587255736Sdavidch if (riscv_isa_extension_available(NULL, ZVKB)) { 588255736Sdavidch err = crypto_register_skcipher( 589255736Sdavidch &riscv64_zvkned_zvkb_aes_skcipher_alg); 590255736Sdavidch if (err) 591255736Sdavidch goto unregister_zvkned_skcipher_algs; 592255736Sdavidch } 593255736Sdavidch 594255736Sdavidch if (riscv64_aes_xts_supported()) { 595255736Sdavidch err = crypto_register_skcipher( 596255736Sdavidch &riscv64_zvkned_zvbb_zvkg_aes_skcipher_alg); 597255736Sdavidch if (err) 598255736Sdavidch goto unregister_zvkned_zvkb_skcipher_alg; 599255736Sdavidch } 600255736Sdavidch } 601255736Sdavidch 602255736Sdavidch return err; 603255736Sdavidch 604255736Sdavidchunregister_zvkned_zvkb_skcipher_alg: 605255736Sdavidch if (riscv_isa_extension_available(NULL, ZVKB)) 606255736Sdavidch crypto_unregister_skcipher(&riscv64_zvkned_zvkb_aes_skcipher_alg); 607255736Sdavidchunregister_zvkned_skcipher_algs: 608255736Sdavidch crypto_unregister_skciphers(riscv64_zvkned_aes_skcipher_algs, 609255736Sdavidch ARRAY_SIZE(riscv64_zvkned_aes_skcipher_algs)); 610255736Sdavidchunregister_zvkned_cipher_alg: 611255736Sdavidch crypto_unregister_alg(&riscv64_zvkned_aes_cipher_alg); 612255736Sdavidch return err; 613255736Sdavidch} 614255736Sdavidch 615255736Sdavidchstatic void __exit riscv64_aes_mod_exit(void) 616255736Sdavidch{ 617255736Sdavidch if (riscv64_aes_xts_supported()) 618255736Sdavidch crypto_unregister_skcipher(&riscv64_zvkned_zvbb_zvkg_aes_skcipher_alg); 619255736Sdavidch if (riscv_isa_extension_available(NULL, ZVKB)) 620255736Sdavidch crypto_unregister_skcipher(&riscv64_zvkned_zvkb_aes_skcipher_alg); 621255736Sdavidch crypto_unregister_skciphers(riscv64_zvkned_aes_skcipher_algs, 622255736Sdavidch ARRAY_SIZE(riscv64_zvkned_aes_skcipher_algs)); 623255736Sdavidch crypto_unregister_alg(&riscv64_zvkned_aes_cipher_alg); 624255736Sdavidch} 625255736Sdavidch 626255736Sdavidchmodule_init(riscv64_aes_mod_init); 627255736Sdavidchmodule_exit(riscv64_aes_mod_exit); 628255736Sdavidch 629255736SdavidchMODULE_DESCRIPTION("AES-ECB/CBC/CTS/CTR/XTS (RISC-V accelerated)"); 630255736SdavidchMODULE_AUTHOR("Jerry Shih <jerry.shih@sifive.com>"); 631255736SdavidchMODULE_LICENSE("GPL"); 632255736SdavidchMODULE_ALIAS_CRYPTO("aes"); 633255736SdavidchMODULE_ALIAS_CRYPTO("ecb(aes)"); 634255736SdavidchMODULE_ALIAS_CRYPTO("cbc(aes)"); 635255736SdavidchMODULE_ALIAS_CRYPTO("cts(cbc(aes))"); 636255736SdavidchMODULE_ALIAS_CRYPTO("ctr(aes)"); 637255736SdavidchMODULE_ALIAS_CRYPTO("xts(aes)"); 638255736Sdavidch