/* SPDX-License-Identifier: GPL-2.0-only OR BSD-3-Clause */
/*
 * AES CTR mode by8 optimization with AVX instructions. (x86_64)
 *
 * Copyright(c) 2014 Intel Corporation.
 *
 * Contact Information:
 * James Guilford <james.guilford@intel.com>
 * Sean Gulley <sean.m.gulley@intel.com>
 * Chandramouli Narayanan <mouli@linux.intel.com>
 */
/*
 * This is the AES128/192/256 CTR mode optimization implementation. It requires
 * the support of Intel(R) AESNI and AVX instructions.
 *
 * This work was inspired by the AES CTR mode optimization published
 * in the Intel Optimized IPSEC Cryptographic library.
 * Additional information on it can be found at:
 *    https://github.com/intel/intel-ipsec-mb
 */

#include <linux/linkage.h>

#define VMOVDQ		vmovdqu
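/*
 * VMOVDQ is used for all loads/stores through p_in and p_out, which are
 * caller-supplied buffers with no alignment guarantee, hence the unaligned
 * vmovdqu.  The expanded key schedule is loaded with vmovdqa instead, on
 * the assumption that the caller provides it 16-byte aligned.
 */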

/*
 * Note: the "x" prefix in these aliases means "this is an xmm register".  The
 * alias prefixes have no relation to XCTR where the "X" prefix means "XOR
 * counter".
 */
#define xdata0		%xmm0
#define xdata1		%xmm1
#define xdata2		%xmm2
#define xdata3		%xmm3
#define xdata4		%xmm4
#define xdata5		%xmm5
#define xdata6		%xmm6
#define xdata7		%xmm7
#define xcounter	%xmm8	// CTR mode only
#define xiv		%xmm8	// XCTR mode only
#define xbyteswap	%xmm9	// CTR mode only
#define xtmp		%xmm9	// XCTR mode only
#define xkey0		%xmm10
#define xkey4		%xmm11
#define xkey8		%xmm12
#define xkey12		%xmm13
#define xkeyA		%xmm14
#define xkeyB		%xmm15
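/*
 * Register budget: xmm0-7 hold the eight counter/data blocks processed in
 * parallel ("by8").  xkey0/xkey4/xkey8/xkey12 cache every third or fourth
 * round key (depending on key length) across main-loop iterations (see
 * do_aes_noload), while xkeyA and xkeyB are scratch: they cycle through the
 * remaining round keys and are reused to stage input blocks for the final XOR.
 */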

#define p_in		%rdi
#define p_iv		%rsi
#define p_keys		%rdx
#define p_out		%rcx
#define num_bytes	%r8
#define counter		%r9	// XCTR mode only
#define tmp		%r10
#define	DDQ_DATA	0
#define	XDATA		1
#define KEY_128		1
#define KEY_192		2
#define KEY_256		3
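/*
 * KEY_128/KEY_192/KEY_256 select the AES key length, i.e. 10, 12 or 14
 * encryption rounds (11, 13 or 15 round keys in the schedule at p_keys).
 */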

.section .rodata
.align 16

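/*
 * byteswap_const reverses the byte order of a 128-bit block (CTR keeps the
 * counter big-endian on the wire).  ddq_add_1..8 are 128-bit little-endian
 * increments; ddq_low_msk and ddq_high_add_1 are used to propagate a carry
 * out of the low 64 bits of the counter into the high 64 bits.
 */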
byteswap_const:
	.octa 0x000102030405060708090A0B0C0D0E0F
ddq_low_msk:
	.octa 0x0000000000000000FFFFFFFFFFFFFFFF
ddq_high_add_1:
	.octa 0x00000000000000010000000000000000
ddq_add_1:
	.octa 0x00000000000000000000000000000001
ddq_add_2:
	.octa 0x00000000000000000000000000000002
ddq_add_3:
	.octa 0x00000000000000000000000000000003
ddq_add_4:
	.octa 0x00000000000000000000000000000004
ddq_add_5:
	.octa 0x00000000000000000000000000000005
ddq_add_6:
	.octa 0x00000000000000000000000000000006
ddq_add_7:
	.octa 0x00000000000000000000000000000007
ddq_add_8:
	.octa 0x00000000000000000000000000000008

.text

/* generate a unique variable for an xmm register */
.macro setxdata n
	var_xdata = %xmm\n
.endm

/* concatenate the numeric 'id' onto the symbol 'name' */
.macro club name, id
.altmacro
	.if \name == XDATA
		setxdata %\id
	.endif
.noaltmacro
.endm
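/*
 * Example: "club XDATA, 5" expands (via .altmacro evaluation of %\id) to
 * "setxdata 5", leaving var_xdata aliased to %xmm5 for the instructions
 * that follow.
 */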

/*
 * do_aes num_in_par load_keys key_len xctr
 * Encrypt 'num_in_par' (1..8) counter blocks and XOR them with the input.
 * 'load_keys' selects whether xkey0/4/8/12 are (re)loaded from p_keys;
 * 'xctr' selects XCTR instead of standard CTR counter generation.
 * This increments p_in, but not p_out.
 */
.macro do_aes b, k, key_len, xctr
	.set by, \b
	.set load_keys, \k
	.set klen, \key_len

	.if (load_keys)
		vmovdqa	0*16(p_keys), xkey0
	.endif

	.if \xctr
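		/*
		 * XCTR: block i of this batch is E_K((counter + i) XOR IV),
		 * with the 64-bit block counter placed little-endian in the
		 * low half of the 128-bit block (no byte swapping needed).
		 */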
		movq counter, xtmp
		.set i, 0
		.rept (by)
			club XDATA, i
			vpaddq	(ddq_add_1 + 16 * i)(%rip), xtmp, var_xdata
			.set i, (i +1)
		.endr
		.set i, 0
		.rept (by)
			club	XDATA, i
			vpxor	xiv, var_xdata, var_xdata
			.set i, (i +1)
		.endr
	.else
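		/*
		 * Standard CTR: the counter is big-endian on the wire, but
		 * xcounter is kept byte-swapped so it can be incremented with
		 * vpaddq.  Roughly, for each block i (a sketch of the loop
		 * below, not literal code):
		 *
		 *	block[i] = bswap128(xcounter + i);
		 *	if the low 64 bits wrapped to zero, add 1 to the
		 *	high 64 bits (ddq_high_add_1) to complete the carry.
		 *
		 * vptest against ddq_low_msk sets ZF when the low qword is 0.
		 */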
		vpshufb	xbyteswap, xcounter, xdata0
		.set i, 1
		.rept (by - 1)
			club XDATA, i
			vpaddq	(ddq_add_1 + 16 * (i - 1))(%rip), xcounter, var_xdata
			vptest	ddq_low_msk(%rip), var_xdata
			jnz 1f
			vpaddq	ddq_high_add_1(%rip), var_xdata, var_xdata
			vpaddq	ddq_high_add_1(%rip), xcounter, xcounter
			1:
			vpshufb	xbyteswap, var_xdata, var_xdata
			.set i, (i +1)
		.endr
	.endif

	vmovdqa	1*16(p_keys), xkeyA

	vpxor	xkey0, xdata0, xdata0
	.if \xctr
		add $by, counter
	.else
		vpaddq	(ddq_add_1 + 16 * (by - 1))(%rip), xcounter, xcounter
		vptest	ddq_low_msk(%rip), xcounter
		jnz	1f
		vpaddq	ddq_high_add_1(%rip), xcounter, xcounter
		1:
	.endif

	.set i, 1
	.rept (by - 1)
		club XDATA, i
		vpxor	xkey0, var_xdata, var_xdata
		.set i, (i +1)
	.endr

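	/*
	 * Rounds 1..N alternate their key between xkeyA and xkeyB so that
	 * the vmovdqa fetching the next round key can issue while the
	 * current round's vaesenc instructions are still in flight.
	 */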
	vmovdqa	2*16(p_keys), xkeyB

	.set i, 0
	.rept by
		club XDATA, i
		vaesenc	xkeyA, var_xdata, var_xdata		/* key 1 */
		.set i, (i +1)
	.endr

	.if (klen == KEY_128)
		.if (load_keys)
			vmovdqa	3*16(p_keys), xkey4
		.endif
	.else
		vmovdqa	3*16(p_keys), xkeyA
	.endif

	.set i, 0
	.rept by
		club XDATA, i
		vaesenc	xkeyB, var_xdata, var_xdata		/* key 2 */
		.set i, (i +1)
	.endr

	add	$(16*by), p_in
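	/*
	 * p_in is advanced past this batch up front; the input blocks are
	 * loaded later at negative offsets, i.e. (i*16 - 16*by)(p_in).
	 */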

	.if (klen == KEY_128)
		vmovdqa	4*16(p_keys), xkeyB
	.else
		.if (load_keys)
			vmovdqa	4*16(p_keys), xkey4
		.endif
	.endif

	.set i, 0
	.rept by
		club XDATA, i
		/* key 3 */
		.if (klen == KEY_128)
			vaesenc	xkey4, var_xdata, var_xdata
		.else
			vaesenc	xkeyA, var_xdata, var_xdata
		.endif
		.set i, (i +1)
	.endr

	vmovdqa	5*16(p_keys), xkeyA

	.set i, 0
	.rept by
		club XDATA, i
		/* key 4 */
		.if (klen == KEY_128)
			vaesenc	xkeyB, var_xdata, var_xdata
		.else
			vaesenc	xkey4, var_xdata, var_xdata
		.endif
		.set i, (i +1)
	.endr

	.if (klen == KEY_128)
		.if (load_keys)
			vmovdqa	6*16(p_keys), xkey8
		.endif
	.else
		vmovdqa	6*16(p_keys), xkeyB
	.endif

	.set i, 0
	.rept by
		club XDATA, i
		vaesenc	xkeyA, var_xdata, var_xdata		/* key 5 */
		.set i, (i +1)
	.endr

	vmovdqa	7*16(p_keys), xkeyA

	.set i, 0
	.rept by
		club XDATA, i
		/* key 6 */
		.if (klen == KEY_128)
			vaesenc	xkey8, var_xdata, var_xdata
		.else
			vaesenc	xkeyB, var_xdata, var_xdata
		.endif
		.set i, (i +1)
	.endr

	.if (klen == KEY_128)
		vmovdqa	8*16(p_keys), xkeyB
	.else
		.if (load_keys)
			vmovdqa	8*16(p_keys), xkey8
		.endif
	.endif

	.set i, 0
	.rept by
		club XDATA, i
		vaesenc	xkeyA, var_xdata, var_xdata		/* key 7 */
		.set i, (i +1)
	.endr

	.if (klen == KEY_128)
		.if (load_keys)
			vmovdqa	9*16(p_keys), xkey12
		.endif
	.else
		vmovdqa	9*16(p_keys), xkeyA
	.endif

	.set i, 0
	.rept by
		club XDATA, i
		/* key 8 */
		.if (klen == KEY_128)
			vaesenc	xkeyB, var_xdata, var_xdata
		.else
			vaesenc	xkey8, var_xdata, var_xdata
		.endif
		.set i, (i +1)
	.endr

	vmovdqa	10*16(p_keys), xkeyB

	.set i, 0
	.rept by
		club XDATA, i
		/* key 9 */
		.if (klen == KEY_128)
			vaesenc	xkey12, var_xdata, var_xdata
		.else
			vaesenc	xkeyA, var_xdata, var_xdata
		.endif
		.set i, (i +1)
	.endr

	.if (klen != KEY_128)
		vmovdqa	11*16(p_keys), xkeyA
	.endif

	.set i, 0
	.rept by
		club XDATA, i
		/* key 10 */
		.if (klen == KEY_128)
			vaesenclast	xkeyB, var_xdata, var_xdata
		.else
			vaesenc	xkeyB, var_xdata, var_xdata
		.endif
		.set i, (i +1)
	.endr

	.if (klen != KEY_128)
		.if (load_keys)
			vmovdqa	12*16(p_keys), xkey12
		.endif

		.set i, 0
		.rept by
			club XDATA, i
			vaesenc	xkeyA, var_xdata, var_xdata	/* key 11 */
			.set i, (i +1)
		.endr

		.if (klen == KEY_256)
			vmovdqa	13*16(p_keys), xkeyA
		.endif

		.set i, 0
		.rept by
			club XDATA, i
			.if (klen == KEY_256)
				/* key 12 */
				vaesenc	xkey12, var_xdata, var_xdata
			.else
				vaesenclast xkey12, var_xdata, var_xdata
			.endif
			.set i, (i +1)
		.endr

		.if (klen == KEY_256)
			vmovdqa	14*16(p_keys), xkeyB

			.set i, 0
			.rept by
				club XDATA, i
				/* key 13 */
				vaesenc	xkeyA, var_xdata, var_xdata
				.set i, (i +1)
			.endr

			.set i, 0
			.rept by
				club XDATA, i
				/* key 14 */
				vaesenclast	xkeyB, var_xdata, var_xdata
				.set i, (i +1)
			.endr
		.endif
	.endif

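	/*
	 * Keystream is ready in xdata0..xdata(by-1); XOR it with the input.
	 * xkeyA/xkeyB are free at this point and are reused to stage two
	 * input blocks per iteration.
	 */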
	.set i, 0
	.rept (by / 2)
		.set j, (i+1)
		VMOVDQ	(i*16 - 16*by)(p_in), xkeyA
		VMOVDQ	(j*16 - 16*by)(p_in), xkeyB
		club XDATA, i
		vpxor	xkeyA, var_xdata, var_xdata
		club XDATA, j
		vpxor	xkeyB, var_xdata, var_xdata
		.set i, (i+2)
	.endr

	.if (i < by)
		VMOVDQ	(i*16 - 16*by)(p_in), xkeyA
		club XDATA, i
		vpxor	xkeyA, var_xdata, var_xdata
	.endif

	.set i, 0
	.rept by
		club XDATA, i
		VMOVDQ	var_xdata, i*16(p_out)
		.set i, (i+1)
	.endr
.endm

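/*
 * do_aes_load reloads the cached round keys (xkey0/4/8/12) from p_keys;
 * do_aes_noload assumes they are already live, which lets the 8-block main
 * loop skip the reloads on every iteration.
 */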
.macro do_aes_load val, key_len, xctr
	do_aes \val, 1, \key_len, \xctr
.endm

.macro do_aes_noload val, key_len, xctr
	do_aes \val, 0, \key_len, \xctr
.endm

/* main body of the AES CTR/XCTR routines */
.macro do_aes_ctrmain key_len, xctr
	cmp	$16, num_bytes
	jb	.Ldo_return2\xctr\key_len

	.if \xctr
		shr	$4, counter
		vmovdqu	(p_iv), xiv
	.else
		vmovdqa	byteswap_const(%rip), xbyteswap
		vmovdqu	(p_iv), xcounter
		vpshufb	xbyteswap, xcounter, xcounter
	.endif

	mov	num_bytes, tmp
	and	$(7*16), tmp
	jz	.Lmult_of_8_blks\xctr\key_len
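	/*
	 * Handle the num_bytes % (8*16) remainder first with a sized
	 * do_aes_load call (1..7 blocks), then fall into the
	 * 8-blocks-at-a-time main loop for whatever multiple of 128 bytes
	 * remains.
	 */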

	/* 1 <= tmp <= 7 blocks, i.e. 16 <= tmp <= 7*16 bytes */
	cmp	$(4*16), tmp
	jg	.Lgt4\xctr\key_len
	je	.Leq4\xctr\key_len

.Llt4\xctr\key_len:
	cmp	$(2*16), tmp
	jg	.Leq3\xctr\key_len
	je	.Leq2\xctr\key_len

.Leq1\xctr\key_len:
	do_aes_load	1, \key_len, \xctr
	add	$(1*16), p_out
	and	$(~7*16), num_bytes
	jz	.Ldo_return2\xctr\key_len
	jmp	.Lmain_loop2\xctr\key_len

.Leq2\xctr\key_len:
	do_aes_load	2, \key_len, \xctr
	add	$(2*16), p_out
	and	$(~7*16), num_bytes
	jz	.Ldo_return2\xctr\key_len
	jmp	.Lmain_loop2\xctr\key_len

.Leq3\xctr\key_len:
	do_aes_load	3, \key_len, \xctr
	add	$(3*16), p_out
	and	$(~7*16), num_bytes
	jz	.Ldo_return2\xctr\key_len
	jmp	.Lmain_loop2\xctr\key_len

.Leq4\xctr\key_len:
	do_aes_load	4, \key_len, \xctr
	add	$(4*16), p_out
	and	$(~7*16), num_bytes
	jz	.Ldo_return2\xctr\key_len
	jmp	.Lmain_loop2\xctr\key_len

.Lgt4\xctr\key_len:
	cmp	$(6*16), tmp
	jg	.Leq7\xctr\key_len
	je	.Leq6\xctr\key_len

.Leq5\xctr\key_len:
	do_aes_load	5, \key_len, \xctr
	add	$(5*16), p_out
	and	$(~7*16), num_bytes
	jz	.Ldo_return2\xctr\key_len
	jmp	.Lmain_loop2\xctr\key_len

.Leq6\xctr\key_len:
	do_aes_load	6, \key_len, \xctr
	add	$(6*16), p_out
	and	$(~7*16), num_bytes
	jz	.Ldo_return2\xctr\key_len
	jmp	.Lmain_loop2\xctr\key_len

.Leq7\xctr\key_len:
	do_aes_load	7, \key_len, \xctr
	add	$(7*16), p_out
	and	$(~7*16), num_bytes
	jz	.Ldo_return2\xctr\key_len
	jmp	.Lmain_loop2\xctr\key_len

.Lmult_of_8_blks\xctr\key_len:
	.if (\key_len != KEY_128)
		vmovdqa	0*16(p_keys), xkey0
		vmovdqa	4*16(p_keys), xkey4
		vmovdqa	8*16(p_keys), xkey8
		vmovdqa	12*16(p_keys), xkey12
	.else
		vmovdqa	0*16(p_keys), xkey0
		vmovdqa	3*16(p_keys), xkey4
		vmovdqa	6*16(p_keys), xkey8
		vmovdqa	9*16(p_keys), xkey12
	.endif
.align 16
.Lmain_loop2\xctr\key_len:
	/* num_bytes is a multiple of 8 blocks (8*16 bytes) and > 0 */
	do_aes_noload	8, \key_len, \xctr
	add	$(8*16), p_out
	sub	$(8*16), num_bytes
	jne	.Lmain_loop2\xctr\key_len

.Ldo_return2\xctr\key_len:
	.if !\xctr
		/* return the updated IV */
		vpshufb	xbyteswap, xcounter, xcounter
		vmovdqu	xcounter, (p_iv)
	.endif
	RET
.endm

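/*
 * A rough sketch of how C glue code might invoke these routines.  They
 * clobber XMM state, so the caller must bracket calls with
 * kernel_fpu_begin()/kernel_fpu_end(); 'ctx->key_enc' is illustrative of
 * a typical expanded-key layout, not a fixed interface:
 *
 *	kernel_fpu_begin();
 *	aes_ctr_enc_128_avx_by8(src, iv, ctx->key_enc, dst, nbytes);
 *	kernel_fpu_end();
 */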
/*
 * routine to do AES128 CTR enc/decrypt "by8"
 * XMM registers are clobbered.
 * Saving/restoring must be done at a higher level
 * aes_ctr_enc_128_avx_by8(void *in, void *iv, void *keys, void *out,
 *			unsigned int num_bytes)
 */
SYM_FUNC_START(aes_ctr_enc_128_avx_by8)
	/* call the aes main loop */
	do_aes_ctrmain KEY_128 0

SYM_FUNC_END(aes_ctr_enc_128_avx_by8)

/*
 * routine to do AES192 CTR enc/decrypt "by8"
 * XMM registers are clobbered.
 * Saving/restoring must be done at a higher level
 * aes_ctr_enc_192_avx_by8(void *in, void *iv, void *keys, void *out,
 *			unsigned int num_bytes)
 */
SYM_FUNC_START(aes_ctr_enc_192_avx_by8)
	/* call the aes main loop */
	do_aes_ctrmain KEY_192 0

SYM_FUNC_END(aes_ctr_enc_192_avx_by8)

/*
 * routine to do AES256 CTR enc/decrypt "by8"
 * XMM registers are clobbered.
 * Saving/restoring must be done at a higher level
 * aes_ctr_enc_256_avx_by8(void *in, void *iv, void *keys, void *out,
 *			unsigned int num_bytes)
 */
SYM_FUNC_START(aes_ctr_enc_256_avx_by8)
	/* call the aes main loop */
	do_aes_ctrmain KEY_256 0

SYM_FUNC_END(aes_ctr_enc_256_avx_by8)

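/*
 * XCTR ("XOR counter") mode, as used by HCTR2: rather than incrementing a
 * big-endian counter, block i (1-based) of the keystream is
 * E_K(IV ^ le128(i)), with the block number laid out little-endian.  The
 * byte_ctr argument is the number of bytes already processed; shr $4 in
 * do_aes_ctrmain converts it to the block number at which this call resumes.
 */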
/*
 * routine to do AES128 XCTR enc/decrypt "by8"
 * XMM registers are clobbered.
 * Saving/restoring must be done at a higher level
 * aes_xctr_enc_128_avx_by8(const u8 *in, const u8 *iv, const void *keys,
 * 	u8* out, unsigned int num_bytes, unsigned int byte_ctr)
 */
SYM_FUNC_START(aes_xctr_enc_128_avx_by8)
	/* call the aes main loop */
	do_aes_ctrmain KEY_128 1

SYM_FUNC_END(aes_xctr_enc_128_avx_by8)

/*
 * routine to do AES192 XCTR enc/decrypt "by8"
 * XMM registers are clobbered.
 * Saving/restoring must be done at a higher level
 * aes_xctr_enc_192_avx_by8(const u8 *in, const u8 *iv, const void *keys,
 * 	u8* out, unsigned int num_bytes, unsigned int byte_ctr)
 */
SYM_FUNC_START(aes_xctr_enc_192_avx_by8)
	/* call the aes main loop */
	do_aes_ctrmain KEY_192 1

SYM_FUNC_END(aes_xctr_enc_192_avx_by8)

/*
 * routine to do AES256 XCTR enc/decrypt "by8"
 * XMM registers are clobbered.
 * Saving/restoring must be done at a higher level
 * aes_xctr_enc_256_avx_by8(const u8 *in, const u8 *iv, const void *keys,
 * 	u8* out, unsigned int num_bytes, unsigned int byte_ctr)
 */
SYM_FUNC_START(aes_xctr_enc_256_avx_by8)
	/* call the aes main loop */
	do_aes_ctrmain KEY_256 1

SYM_FUNC_END(aes_xctr_enc_256_avx_by8)