1/* SPDX-License-Identifier: GPL-2.0-only */
2/*
3 * linux/arch/arm64/crypto/aes-modes.S - chaining mode wrappers for AES
4 *
5 * Copyright (C) 2013 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
6 */
7
8/* included by aes-ce.S and aes-neon.S */
9
/* Code section, 16-byte aligned (.align takes a power of two on AArch64). */
10	.text
11	.align		4
12
	/*
	 * MAX_STRIDE is the number of blocks processed per iteration of the
	 * interleaved bulk loops below.  The including file may define it
	 * before including this one; otherwise it defaults to 4.
	 */
13#ifndef MAX_STRIDE
14#define MAX_STRIDE	4
15#endif
16
	/*
	 * ST4(...) emits its argument only when MAX_STRIDE == 4, and ST5(...)
	 * emits its argument for any other value.  aes_encrypt_block5x and
	 * aes_decrypt_block5x are only assembled when MAX_STRIDE == 5, so 4 and
	 * 5 are the only values the rest of this file supports.
	 */
17#if MAX_STRIDE == 4
18#define ST4(x...) x
19#define ST5(x...)
20#else
21#define ST4(x...)
22#define ST5(x...) x
23#endif
24
/*
 * aes_encrypt_block4x - encrypt v0-v3 in place with the round keys at x2
 * and the round count in w3, using the encrypt_block4x macro (not defined
 * in this file; presumably supplied by the including aes-ce.S/aes-neon.S).
 * x8 and w7 are passed as the remaining macro operands and should be
 * treated as clobbered.  Callers in this file keep live values in v4-v7
 * across the bl, so the macro must leave those registers intact.
 */
25SYM_FUNC_START_LOCAL(aes_encrypt_block4x)
26	encrypt_block4x	v0, v1, v2, v3, w3, x2, x8, w7
27	ret
28SYM_FUNC_END(aes_encrypt_block4x)
29
/*
 * aes_decrypt_block4x - decrypt v0-v3 in place with the round keys at x2
 * and the round count in w3, using the decrypt_block4x macro (not defined
 * in this file; presumably supplied by the including aes-ce.S/aes-neon.S).
 * x8 and w7 are passed as the remaining macro operands and should be
 * treated as clobbered.  Callers in this file (CBC and XTS decryption) keep
 * live values in v4-v7 across the bl, so the macro must leave them intact.
 */
30SYM_FUNC_START_LOCAL(aes_decrypt_block4x)
31	decrypt_block4x	v0, v1, v2, v3, w3, x2, x8, w7
32	ret
33SYM_FUNC_END(aes_decrypt_block4x)
34
/*
 * Five-block variants, assembled only when MAX_STRIDE == 5.  They process
 * v0-v4 in place with the same key (x2), round count (w3) and extra
 * operands (x8, w7 - treat as clobbered) as the 4x helpers.  Callers in
 * this file keep live values in v5-v7 across the bl.
 */
35#if MAX_STRIDE == 5
36SYM_FUNC_START_LOCAL(aes_encrypt_block5x)
37	encrypt_block5x	v0, v1, v2, v3, v4, w3, x2, x8, w7
38	ret
39SYM_FUNC_END(aes_encrypt_block5x)
40
41SYM_FUNC_START_LOCAL(aes_decrypt_block5x)
42	decrypt_block5x	v0, v1, v2, v3, v4, w3, x2, x8, w7
43	ret
44SYM_FUNC_END(aes_decrypt_block5x)
45#endif
46
47	/*
48	 * aes_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,
49	 *		   int blocks)
50	 * aes_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,
51	 *		   int blocks)
52	 */
53
/*
 * aes_ecb_encrypt - ECB-encrypt @blocks 16-byte blocks from in[] to out[].
 *
 * x0 = out, x1 = in (both post-incremented), x2 = rk, w3 = rounds,
 * w4 = blocks remaining (blocks == 0 falls straight through to the exit).
 * frame_push/frame_pop bracket the body, presumably to preserve x30, which
 * the nested bl calls overwrite.
 */
54AES_FUNC_START(aes_ecb_encrypt)
55	frame_push	0
56
57	enc_prepare	w3, x2, x5
58
	/* Bulk path: MAX_STRIDE blocks per iteration while that many remain. */
59.LecbencloopNx:
60	subs		w4, w4, #MAX_STRIDE
61	bmi		.Lecbenc1x
62	ld1		{v0.16b-v3.16b}, [x1], #64	/* get 4 pt blocks (ST5 adds a 5th) */
63ST4(	bl		aes_encrypt_block4x		)
64ST5(	ld1		{v4.16b}, [x1], #16		)
65ST5(	bl		aes_encrypt_block5x		)
66	st1		{v0.16b-v3.16b}, [x0], #64
67ST5(	st1		{v4.16b}, [x0], #16		)
68	b		.LecbencloopNx
69.Lecbenc1x:
	/* Undo the over-subtraction: 0..MAX_STRIDE-1 blocks remain. */
70	adds		w4, w4, #MAX_STRIDE
71	beq		.Lecbencout
72.Lecbencloop:
73	ld1		{v0.16b}, [x1], #16		/* get next pt block */
74	encrypt_block	v0, w3, x2, x5, w6
75	st1		{v0.16b}, [x0], #16
76	subs		w4, w4, #1
77	bne		.Lecbencloop
78.Lecbencout:
79	frame_pop
80	ret
81AES_FUNC_END(aes_ecb_encrypt)
82
83
/*
 * aes_ecb_decrypt - ECB-decrypt @blocks 16-byte blocks from in[] to out[].
 *
 * x0 = out, x1 = in (both post-incremented), x2 = rk, w3 = rounds,
 * w4 = blocks remaining (blocks == 0 falls straight through to the exit).
 * frame_push/frame_pop bracket the body, presumably to preserve x30, which
 * the nested bl calls overwrite.
 */
84AES_FUNC_START(aes_ecb_decrypt)
85	frame_push	0
86
87	dec_prepare	w3, x2, x5
88
	/* Bulk path: MAX_STRIDE blocks per iteration while that many remain. */
89.LecbdecloopNx:
90	subs		w4, w4, #MAX_STRIDE
91	bmi		.Lecbdec1x
92	ld1		{v0.16b-v3.16b}, [x1], #64	/* get 4 ct blocks (ST5 adds a 5th) */
93ST4(	bl		aes_decrypt_block4x		)
94ST5(	ld1		{v4.16b}, [x1], #16		)
95ST5(	bl		aes_decrypt_block5x		)
96	st1		{v0.16b-v3.16b}, [x0], #64
97ST5(	st1		{v4.16b}, [x0], #16		)
98	b		.LecbdecloopNx
99.Lecbdec1x:
	/* Undo the over-subtraction: 0..MAX_STRIDE-1 blocks remain. */
100	adds		w4, w4, #MAX_STRIDE
101	beq		.Lecbdecout
102.Lecbdecloop:
103	ld1		{v0.16b}, [x1], #16		/* get next ct block */
104	decrypt_block	v0, w3, x2, x5, w6
105	st1		{v0.16b}, [x0], #16
106	subs		w4, w4, #1
107	bne		.Lecbdecloop
108.Lecbdecout:
109	frame_pop
110	ret
111AES_FUNC_END(aes_ecb_decrypt)
112
113
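114	/*
115	 * aes_cbc_encrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,
116	 *		   int blocks, u8 iv[])
117	 * aes_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,
118	 *		   int blocks, u8 iv[])
119	 * aes_essiv_cbc_encrypt(u8 out[], u8 const in[], u32 const rk1[],
120	 *			 int rounds, int blocks, u8 iv[],
121	 *			 u32 const rk2[]);
122	 * aes_essiv_cbc_decrypt(u8 out[], u8 const in[], u32 const rk1[],
123	 *			 int rounds, int blocks, u8 iv[],
124	 *			 u32 const rk2[]);
	 *
	 * In the ESSIV variants, iv[] is first encrypted with rk2, which is
	 * always treated as an AES-256 key schedule (14 rounds are hard-coded).
	 * rk1 and rounds are then used for the data.  All four routines write
	 * the final chaining value back to iv[].
125	 */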
126
/*
 * CBC encryption.  aes_essiv_cbc_encrypt first derives the IV by encrypting
 * iv[] with rk2 (x6), always using 14 rounds (AES-256).  It then moves on
 * to the data key rk1 (x2, w3 rounds) via enc_switch_key (defined by the
 * including file) and joins aes_cbc_encrypt at its loop.
 *
 * x0 = out, x1 = in (both post-incremented), w4 = blocks remaining,
 * x5 = iv[], v4 = chaining value (the IV, then the previous ciphertext).
 * No bl is made, so no frame is pushed.
 *
 * CBC encryption is inherently serial, so the 4-block loop still encrypts
 * one block at a time.  It only batches the loads and stores, and it uses a
 * fixed stride of 4 regardless of MAX_STRIDE.
 */
127AES_FUNC_START(aes_essiv_cbc_encrypt)
128	ld1		{v4.16b}, [x5]			/* get iv */
129
130	mov		w8, #14				/* AES-256: 14 rounds */
131	enc_prepare	w8, x6, x7
132	encrypt_block	v4, w8, x6, x7, w9
133	enc_switch_key	w3, x2, x6
134	b		.Lcbcencloop4x
135
136AES_FUNC_START(aes_cbc_encrypt)
137	ld1		{v4.16b}, [x5]			/* get iv */
138	enc_prepare	w3, x2, x6
139
140.Lcbcencloop4x:
141	subs		w4, w4, #4
142	bmi		.Lcbcenc1x
143	ld1		{v0.16b-v3.16b}, [x1], #64	/* get 4 pt blocks */
144	eor		v0.16b, v0.16b, v4.16b		/* ..and xor with iv */
145	encrypt_block	v0, w3, x2, x6, w7
146	eor		v1.16b, v1.16b, v0.16b
147	encrypt_block	v1, w3, x2, x6, w7
148	eor		v2.16b, v2.16b, v1.16b
149	encrypt_block	v2, w3, x2, x6, w7
150	eor		v3.16b, v3.16b, v2.16b
151	encrypt_block	v3, w3, x2, x6, w7
152	st1		{v0.16b-v3.16b}, [x0], #64
153	mov		v4.16b, v3.16b
154	b		.Lcbcencloop4x
155.Lcbcenc1x:
	/* Undo the over-subtraction: 0..3 blocks remain. */
156	adds		w4, w4, #4
157	beq		.Lcbcencout
158.Lcbcencloop:
159	ld1		{v0.16b}, [x1], #16		/* get next pt block */
160	eor		v4.16b, v4.16b, v0.16b		/* ..and xor with iv */
161	encrypt_block	v4, w3, x2, x6, w7
162	st1		{v4.16b}, [x0], #16
163	subs		w4, w4, #1
164	bne		.Lcbcencloop
165.Lcbcencout:
166	st1		{v4.16b}, [x5]			/* return iv */
167	ret
	/* Both entry points share the body above, so both end here. */
168AES_FUNC_END(aes_cbc_encrypt)
169AES_FUNC_END(aes_essiv_cbc_encrypt)
170
/*
 * CBC decryption.  aes_essiv_cbc_decrypt first derives the IV by encrypting
 * iv[] with rk2 (x6), always using 14 rounds (AES-256).  It then joins
 * aes_cbc_decrypt, which prepares rk1 (x2, w3 rounds) for decryption.
 *
 * x0 = out, x1 = in (both post-incremented), w4 = blocks remaining,
 * x5 = iv[].  cbciv is a vector register alias (defined outside this file)
 * holding the chaining value: the IV, then the previous ciphertext block.
 * CBC decryption can be parallelised, so the bulk loop uses the
 * MAX_STRIDE-wide helpers.  frame_push presumably preserves x30 across those
 * nested bl calls.
 */
171AES_FUNC_START(aes_essiv_cbc_decrypt)
172	ld1		{cbciv.16b}, [x5]		/* get iv */
173
174	mov		w8, #14				/* AES-256: 14 rounds */
175	enc_prepare	w8, x6, x7
176	encrypt_block	cbciv, w8, x6, x7, w9
177	b		.Lessivcbcdecstart
178
179AES_FUNC_START(aes_cbc_decrypt)
180	ld1		{cbciv.16b}, [x5]		/* get iv */
181.Lessivcbcdecstart:
182	frame_push	0
183	dec_prepare	w3, x2, x6
184
185.LcbcdecloopNx:
186	subs		w4, w4, #MAX_STRIDE
187	bmi		.Lcbcdec1x
188	ld1		{v0.16b-v3.16b}, [x1], #64	/* get 4 ct blocks */
189#if MAX_STRIDE == 5
	/*
	 * v5-v7 keep copies of ct0-ct2 for the post-decryption XOR.  ct3 and
	 * ct4 are re-read from in[] afterwards (x1 rewound by 32): ct3 into v5
	 * for block 4, and ct4 into cbciv as the next chaining value.  The
	 * re-reads precede the stores below, so in-place operation (out == in)
	 * does not overwrite them first.
	 */
190	ld1		{v4.16b}, [x1], #16		/* get 1 ct block */
191	mov		v5.16b, v0.16b
192	mov		v6.16b, v1.16b
193	mov		v7.16b, v2.16b
194	bl		aes_decrypt_block5x
195	sub		x1, x1, #32
196	eor		v0.16b, v0.16b, cbciv.16b
197	eor		v1.16b, v1.16b, v5.16b
198	ld1		{v5.16b}, [x1], #16		/* reload 1 ct block */
199	ld1		{cbciv.16b}, [x1], #16		/* reload 1 ct block */
200	eor		v2.16b, v2.16b, v6.16b
201	eor		v3.16b, v3.16b, v7.16b
202	eor		v4.16b, v4.16b, v5.16b
203#else
	/*
	 * v4-v6 keep copies of ct0-ct2.  ct3 is re-read from in[] (x1 rewound
	 * by 16) to become the next chaining value, again before any store.
	 */
204	mov		v4.16b, v0.16b
205	mov		v5.16b, v1.16b
206	mov		v6.16b, v2.16b
207	bl		aes_decrypt_block4x
208	sub		x1, x1, #16
209	eor		v0.16b, v0.16b, cbciv.16b
210	eor		v1.16b, v1.16b, v4.16b
211	ld1		{cbciv.16b}, [x1], #16		/* reload 1 ct block */
212	eor		v2.16b, v2.16b, v5.16b
213	eor		v3.16b, v3.16b, v6.16b
214#endif
215	st1		{v0.16b-v3.16b}, [x0], #64
216ST5(	st1		{v4.16b}, [x0], #16		)
217	b		.LcbcdecloopNx
218.Lcbcdec1x:
	/* Undo the over-subtraction: 0..MAX_STRIDE-1 blocks remain. */
219	adds		w4, w4, #MAX_STRIDE
220	beq		.Lcbcdecout
221.Lcbcdecloop:
222	ld1		{v1.16b}, [x1], #16		/* get next ct block */
223	mov		v0.16b, v1.16b			/* ...and copy to v0 */
224	decrypt_block	v0, w3, x2, x6, w7
225	eor		v0.16b, v0.16b, cbciv.16b	/* xor with iv => pt */
226	mov		cbciv.16b, v1.16b		/* ct is next iv */
227	st1		{v0.16b}, [x0], #16
228	subs		w4, w4, #1
229	bne		.Lcbcdecloop
230.Lcbcdecout:
231	st1		{cbciv.16b}, [x5]		/* return iv */
232	frame_pop
233	ret
	/* Both entry points share the body above, so both end here. */
234AES_FUNC_END(aes_cbc_decrypt)
235AES_FUNC_END(aes_essiv_cbc_decrypt)
236
237
238	/*
239	 * aes_cbc_cts_encrypt(u8 out[], u8 const in[], u32 const rk[],
240	 *		       int rounds, int bytes, u8 const iv[])
241	 * aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[],
242	 *		       int rounds, int bytes, u8 const iv[])
243	 */
244
/*
 * aes_cbc_cts_encrypt - CBC-encrypt the final two blocks of a message with
 * ciphertext stealing.  The last (full) ciphertext block is written first,
 * followed by the truncated previous block, which is the NIST CS3 ordering.
 *
 * x4 = bytes.  It is presumably in 17..32 (caller's contract, not checked
 * here).  Let d = bytes - 16, the length of the final partial block.
 * Both loads and both stores overlap, so no byte outside in[0..bytes) or
 * out[0..bytes) is touched.  iv[] (x5) is read but not written back.
 */
245AES_FUNC_START(aes_cbc_cts_encrypt)
246	adr_l		x8, .Lcts_permute_table
247	sub		x4, x4, #16
248	add		x9, x8, #32
249	add		x8, x8, x4
250	sub		x9, x9, x4
	/*
	 * v3 (table + d) moves source bytes 0..d-1 to lanes 16-d..15.
	 * v4 (table + 32 - d) moves source lanes 16-d..15 to lanes 0..d-1.
	 * tbl writes zero to every other lane (index 0xff is out of range).
	 */
251	ld1		{v3.16b}, [x8]
252	ld1		{v4.16b}, [x9]
253
	/* v0 = in[0..15], v1 = in[d..d+15] (its top d lanes are the tail). */
254	ld1		{v0.16b}, [x1], x4		/* overlapping loads */
255	ld1		{v1.16b}, [x1]
256
257	ld1		{v5.16b}, [x5]			/* get iv */
258	enc_prepare	w3, x2, x6
259
	/* v0 = E(P0 ^ IV); v1 = the partial tail, zero-padded to 16 bytes. */
260	eor		v0.16b, v0.16b, v5.16b		/* xor with iv */
261	tbl		v1.16b, {v1.16b}, v4.16b
262	encrypt_block	v0, w3, x2, x6, w7
263
	/* v1 = E(tail ^ v0); the first d bytes of v0 move to its top lanes. */
264	eor		v1.16b, v1.16b, v0.16b
265	tbl		v0.16b, {v0.16b}, v3.16b
266	encrypt_block	v1, w3, x2, x6, w7
267
	/*
	 * The store at out + d puts the truncated block at out[16..16+d-1].
	 * The zero lanes it also writes are then overwritten by v1 at out[0].
	 */
268	add		x4, x0, x4
269	st1		{v0.16b}, [x4]			/* overlapping stores */
270	st1		{v1.16b}, [x0]
271	ret
272AES_FUNC_END(aes_cbc_cts_encrypt)
273
/*
 * aes_cbc_cts_decrypt - inverse of aes_cbc_cts_encrypt.  The input starts
 * with a full ciphertext block, followed by d = bytes - 16 bytes of the
 * truncated previous block.
 *
 * x4 = bytes, presumably 17..32 (caller's contract, not checked here).
 * Loads and stores overlap as in the encrypt path.  iv[] (x5) is read but
 * not written back.
 */
274AES_FUNC_START(aes_cbc_cts_decrypt)
275	adr_l		x8, .Lcts_permute_table
276	sub		x4, x4, #16
277	add		x9, x8, #32
278	add		x8, x8, x4
279	sub		x9, x9, x4
	/* Permutations as in aes_cbc_cts_encrypt: v3 at table + d, v4 at table + 32 - d. */
280	ld1		{v3.16b}, [x8]
281	ld1		{v4.16b}, [x9]
282
	/* v0 = in[0..15] (full block), v1 = in[d..d+15] (top d lanes = truncated block). */
283	ld1		{v0.16b}, [x1], x4		/* overlapping loads */
284	ld1		{v1.16b}, [x1]
285
286	ld1		{v5.16b}, [x5]			/* get iv */
287	dec_prepare	w3, x2, x6
288
	/*
	 * Decrypt the full block.  Its first d bytes, XORed with the truncated
	 * ciphertext in v1's top lanes, give the final partial plaintext (v2).
	 */
289	decrypt_block	v0, w3, x2, x6, w7
290	tbl		v2.16b, {v0.16b}, v3.16b
291	eor		v2.16b, v2.16b, v1.16b
292
	/*
	 * tbx leaves lanes with an out-of-range index unchanged.  Lanes 0..d-1
	 * therefore take the truncated ciphertext, and the remaining lanes keep
	 * the stolen bytes.  This rebuilds the full previous ciphertext block,
	 * which is then decrypted and XORed with the IV.
	 */
293	tbx		v0.16b, {v1.16b}, v4.16b
294	decrypt_block	v0, w3, x2, x6, w7
295	eor		v0.16b, v0.16b, v5.16b		/* xor with iv */
296
	/* v2 lands at out[16..16+d-1]; v0 at out[0] then overwrites its low lanes. */
297	add		x4, x0, x4
298	st1		{v2.16b}, [x4]			/* overlapping stores */
299	st1		{v0.16b}, [x0]
300	ret
301AES_FUNC_END(aes_cbc_cts_decrypt)
302
/*
 * Byte-permutation table for tbl/tbx, used by ciphertext stealing (CBC and
 * XTS) and by the CTR tail code.  It holds 16 x 0xff, the identity 0..15,
 * then 16 x 0xff.  Loading 16 bytes at offset n (0..32) gives lane i the
 * index i + n - 16 when that lies in 0..15, and 0xff otherwise.
 * 0xff is out of range: tbl writes 0 to such lanes and tbx leaves them
 * unchanged.  After sshr #7, 0xff entries become all-ones and identity
 * entries become zero, which the CTR tail uses as a byte mask.
 * .align 6 gives 64-byte alignment, so the 48-byte table sits within one
 * 64-byte block (presumably to avoid cache-line splits).
 */
303	.section	".rodata", "a"
304	.align		6
305.Lcts_permute_table:
306	.byte		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
307	.byte		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
308	.byte		 0x0,  0x1,  0x2,  0x3,  0x4,  0x5,  0x6,  0x7
309	.byte		 0x8,  0x9,  0xa,  0xb,  0xc,  0xd,  0xe,  0xf
310	.byte		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
311	.byte		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
312	.previous
313
314	/*
315	 * This macro generates the code for CTR and XCTR mode.
316	 */
/*
 * ctr_encrypt xctr - body shared by aes_ctr_encrypt (xctr = 0) and
 * aes_xctr_encrypt (xctr = 1).  Labels are suffixed with \xctr so that both
 * expansions can coexist in one file.
 *
 * CTR (xctr = 0): IV[] holds a big-endian 128-bit counter.  Only its low
 * 64 bits are kept in a general register (IV_PART).  A carry out of them
 * is propagated into the high half out of line.  The next unused counter
 * value is written back to IV[] on return.
 *
 * XCTR (xctr = 1): the keystream block for counter n is E(IV ^ n).  n is
 * XORed into lane d[0], i.e. bytes 0-7 of the IV read as a little-endian
 * value.  The first block of a call uses n = BYTE_CTR_W / 16 + 1.
 * IV[] is not written back.
 *
 * Keystream is generated MAX_STRIDE blocks at a time via the bl helpers.
 * frame_push presumably preserves x30 across those nested calls.
 */
317.macro ctr_encrypt xctr
318	// Arguments
319	OUT		.req x0
320	IN		.req x1
321	KEY		.req x2
322	ROUNDS_W	.req w3
323	BYTES_W		.req w4
324	IV		.req x5
325	BYTE_CTR_W 	.req w6		// XCTR only
326	// Intermediate values
327	CTR_W		.req w11	// XCTR only
328	CTR		.req x11	// XCTR only
329	IV_PART		.req x12
330	BLOCKS		.req x13
331	BLOCKS_W	.req w13
332
333	frame_push	0
334
	/* IV_PART is not yet live, so it can serve as enc_prepare's third operand. */
335	enc_prepare	ROUNDS_W, KEY, IV_PART
336	ld1		{vctr.16b}, [IV]
337
338	/*
339	 * Keep 64 bits of the IV in a register.  For CTR mode this lets us
340	 * easily increment the IV.  For XCTR mode this lets us efficiently XOR
341	 * the 64-bit counter with the IV.
342	 */
343	.if \xctr
344		umov		IV_PART, vctr.d[0]
345		lsr		CTR_W, BYTE_CTR_W, #4
346	.else
347		umov		IV_PART, vctr.d[1]
348		rev		IV_PART, IV_PART
349	.endif
350
351.LctrloopNx\xctr:
	/*
	 * BLOCKS = min(ceil(BYTES_W / 16), MAX_STRIDE), using the value of
	 * BYTES_W on entry to this iteration.  BYTES_W is pre-decremented by a
	 * full stride.  It goes negative when this is the final, partial
	 * stride, which the tbnz further down tests.
	 */
352	add		BLOCKS_W, BYTES_W, #15
353	sub		BYTES_W, BYTES_W, #MAX_STRIDE << 4
354	lsr		BLOCKS_W, BLOCKS_W, #4
355	mov		w8, #MAX_STRIDE
356	cmp		BLOCKS_W, w8
357	csel		BLOCKS_W, BLOCKS_W, w8, lt
358
359	/*
360	 * Set up the counter values in v0-v{MAX_STRIDE-1}.
361	 *
362	 * If we are encrypting less than MAX_STRIDE blocks, the tail block
363	 * handling code expects the last keystream block to be in
364	 * v{MAX_STRIDE-1}.  For example: if encrypting two blocks with
365	 * MAX_STRIDE=5, then v3 and v4 should have the next two counter blocks.
366	 */
	/*
	 * Advance the running counter past this stride.  In CTR mode, adds
	 * sets the carry flag if the low 64 bits wrap.  The vector movs in
	 * between do not touch the flags, so bcs below still sees it.
	 */
367	.if \xctr
368		add		CTR, CTR, BLOCKS
369	.else
370		adds		IV_PART, IV_PART, BLOCKS
371	.endif
	/* Every counter block starts as a copy of the pre-stride counter. */
372	mov		v0.16b, vctr.16b
373	mov		v1.16b, vctr.16b
374	mov		v2.16b, vctr.16b
375	mov		v3.16b, vctr.16b
376ST5(	mov		v4.16b, vctr.16b		)
377	.if \xctr
		/*
		 * v{MAX_STRIDE-1-k}.d[0] = (CTR - k) ^ IV.d[0], so
		 * v{MAX_STRIDE-1} holds the last block of this stride.
		 */
378		sub		x6, CTR, #MAX_STRIDE - 1
379		sub		x7, CTR, #MAX_STRIDE - 2
380		sub		x8, CTR, #MAX_STRIDE - 3
381		sub		x9, CTR, #MAX_STRIDE - 4
382ST5(		sub		x10, CTR, #MAX_STRIDE - 5	)
383		eor		x6, x6, IV_PART
384		eor		x7, x7, IV_PART
385		eor		x8, x8, IV_PART
386		eor		x9, x9, IV_PART
387ST5(		eor		x10, x10, IV_PART		)
388		mov		v0.d[0], x6
389		mov		v1.d[0], x7
390		mov		v2.d[0], x8
391		mov		v3.d[0], x9
392ST5(		mov		v4.d[0], x10			)
393	.else
394		bcs		0f
395		.subsection	1
396		/*
397		 * This subsection handles carries.
398		 *
399		 * Conditional branching here is allowed with respect to time
400		 * invariance since the branches are dependent on the IV instead
401		 * of the plaintext or key.  This code is rarely executed in
402		 * practice anyway.
403		 */
404
405		/* Apply carry to outgoing counter. */
4060:		umov		x8, vctr.d[0]
407		rev		x8, x8
408		add		x8, x8, #1
409		rev		x8, x8
410		ins		vctr.d[0], x8
411
412		/*
413		 * Apply carry to counter blocks if needed.
414		 *
415		 * Since the carry flag was set, we know 0 <= IV_PART <
416		 * MAX_STRIDE.  Using the value of IV_PART we can determine how
417		 * many counter blocks need to be updated.
418		 */
419		cbz		IV_PART, 2f
		/*
		 * Computed jump into the table below.  Each entry is a bti c
		 * landing pad plus one mov, i.e. 2 x 4 bytes, hence lsl #3.
		 * Entering IV_PART entries before label 1 gives the last
		 * IV_PART counter blocks (those whose low half wrapped) the
		 * incremented high half.  bti is a hint, executed as a NOP
		 * when BTI is not in use, so the entry size is fixed.
		 */
420		adr		x16, 1f
421		sub		x16, x16, IV_PART, lsl #3
422		br		x16
423		bti		c
424		mov		v0.d[0], vctr.d[0]
425		bti		c
426		mov		v1.d[0], vctr.d[0]
427		bti		c
428		mov		v2.d[0], vctr.d[0]
429		bti		c
430		mov		v3.d[0], vctr.d[0]
431ST5(		bti		c				)
432ST5(		mov		v4.d[0], vctr.d[0]		)
4331:		b		2f
434		.previous
435
		/*
		 * Store the advanced low half back into vctr, making it the next
		 * unused counter.  Then fill in the low halves of the counter
		 * blocks: v{MAX_STRIDE-1-k}.d[1] = IV_PART - 1 - k, byte-reversed
		 * back to big-endian.  v0 keeps the pre-stride counter copied
		 * from vctr above.
		 */
4362:		rev		x7, IV_PART
437		ins		vctr.d[1], x7
438		sub		x7, IV_PART, #MAX_STRIDE - 1
439		sub		x8, IV_PART, #MAX_STRIDE - 2
440		sub		x9, IV_PART, #MAX_STRIDE - 3
441		rev		x7, x7
442		rev		x8, x8
443		mov		v1.d[1], x7
444		rev		x9, x9
445ST5(		sub		x10, IV_PART, #MAX_STRIDE - 4	)
446		mov		v2.d[1], x8
447ST5(		rev		x10, x10			)
448		mov		v3.d[1], x9
449ST5(		mov		v4.d[1], x10			)
450	.endif
451
452	/*
453	 * If there are at least MAX_STRIDE blocks left, XOR the data with
454	 * keystream and store.  Otherwise jump to tail handling.
455	 */
456	tbnz		BYTES_W, #31, .Lctrtail\xctr
	/*
	 * Three input blocks are loaded before the bl, presumably to overlap
	 * the loads with the AES rounds.  This relies on the block helper
	 * leaving v5-v7 intact.
	 */
457	ld1		{v5.16b-v7.16b}, [IN], #48
458ST4(	bl		aes_encrypt_block4x		)
459ST5(	bl		aes_encrypt_block5x		)
460	eor		v0.16b, v5.16b, v0.16b
461ST4(	ld1		{v5.16b}, [IN], #16		)
462	eor		v1.16b, v6.16b, v1.16b
463ST5(	ld1		{v5.16b-v6.16b}, [IN], #32	)
464	eor		v2.16b, v7.16b, v2.16b
465	eor		v3.16b, v5.16b, v3.16b
466ST5(	eor		v4.16b, v6.16b, v4.16b		)
467	st1		{v0.16b-v3.16b}, [OUT], #64
468ST5(	st1		{v4.16b}, [OUT], #16		)
469	cbz		BYTES_W, .Lctrout\xctr
470	b		.LctrloopNx\xctr
471
472.Lctrout\xctr:
473	.if !\xctr
474		st1		{vctr.16b}, [IV] /* return next CTR value */
475	.endif
476	frame_pop
477	ret
478
479.Lctrtail\xctr:
480	/*
481	 * Handle up to MAX_STRIDE * 16 - 1 bytes of plaintext
482	 *
483	 * This code expects the last keystream block to be in v{MAX_STRIDE-1}.
484	 * For example: if encrypting two blocks with MAX_STRIDE=5, then v3 and
485	 * v4 should have the next two counter blocks.
486	 *
487	 * This allows us to store the ciphertext by writing to overlapping
488	 * regions of memory.  Any invalid ciphertext blocks get overwritten by
489	 * correctly computed blocks.  This approach greatly simplifies the
490	 * logic for storing the ciphertext.
491	 */
	/*
	 * x13 (the BLOCKS register, no longer needed) = number of bytes in the
	 * final block, 1..16.  w7 keeps BYTES_W & 0xf for the 1x path.
	 */
492	mov		x16, #16
493	ands		w7, BYTES_W, #0xf
494	csel		x13, x7, x16, ne
495
	/*
	 * x14/x15/x16 = 16 if the corresponding leading block exists, else 0.
	 * A zero step makes that load and store alias the next one, so the
	 * junk result for a missing block is overwritten by a later store.
	 */
496ST5(	cmp		BYTES_W, #64 - (MAX_STRIDE << 4))
497ST5(	csel		x14, x16, xzr, gt		)
498	cmp		BYTES_W, #48 - (MAX_STRIDE << 4)
499	csel		x15, x16, xzr, gt
500	cmp		BYTES_W, #32 - (MAX_STRIDE << 4)
501	csel		x16, x16, xzr, gt
	/* Flags from this cmp are also consumed by the csel in the 1x path. */
502	cmp		BYTES_W, #16 - (MAX_STRIDE << 4)
503
504	adr_l		x9, .Lcts_permute_table
505	add		x9, x9, x13
506	ble		.Lctrtail1x\xctr
507
508ST5(	ld1		{v5.16b}, [IN], x14		)
509	ld1		{v6.16b}, [IN], x15
510	ld1		{v7.16b}, [IN], x16
511
512ST4(	bl		aes_encrypt_block4x		)
513ST5(	bl		aes_encrypt_block5x		)
514
	/*
	 * v8 = the second-to-last block.  v9 = the final 16 bytes of input,
	 * whose top x13 lanes hold the final block.  v10 moves the first x13
	 * keystream bytes of v{MAX_STRIDE-1} into those top lanes and zeroes
	 * the rest.
	 */
515	ld1		{v8.16b}, [IN], x13
516	ld1		{v9.16b}, [IN]
517	ld1		{v10.16b}, [x9]
518
519ST4(	eor		v6.16b, v6.16b, v0.16b		)
520ST4(	eor		v7.16b, v7.16b, v1.16b		)
521ST4(	tbl		v3.16b, {v3.16b}, v10.16b	)
522ST4(	eor		v8.16b, v8.16b, v2.16b		)
523ST4(	eor		v9.16b, v9.16b, v3.16b		)
524
525ST5(	eor		v5.16b, v5.16b, v0.16b		)
526ST5(	eor		v6.16b, v6.16b, v1.16b		)
527ST5(	tbl		v4.16b, {v4.16b}, v10.16b	)
528ST5(	eor		v7.16b, v7.16b, v2.16b		)
529ST5(	eor		v8.16b, v8.16b, v3.16b		)
530ST5(	eor		v9.16b, v9.16b, v4.16b		)
531
	/*
	 * v9 is stored before v8, so v8 overwrites v9's low lanes (which hold
	 * untransformed plaintext) with the correct ciphertext.
	 */
532ST5(	st1		{v5.16b}, [OUT], x14		)
533	st1		{v6.16b}, [OUT], x15
534	st1		{v7.16b}, [OUT], x16
535	add		x13, x13, OUT
536	st1		{v9.16b}, [x13]		// overlapping stores
537	st1		{v8.16b}, [OUT]
538	b		.Lctrout\xctr
539
540.Lctrtail1x\xctr:
541	/*
542	 * Handle <= 16 bytes of plaintext
543	 *
544	 * This code always reads and writes 16 bytes.  To avoid out of bounds
545	 * accesses, XCTR and CTR modes must use a temporary buffer when
546	 * encrypting/decrypting less than 16 bytes.
547	 *
548	 * This code is unusual in that it loads the input and stores the output
549	 * relative to the end of the buffers rather than relative to the start.
550	 * This causes unusual behaviour when encrypting/decrypting less than 16
551	 * bytes; the end of the data is expected to be at the end of the
552	 * temporary buffer rather than the start of the data being at the start
553	 * of the temporary buffer.
554	 */
	/*
	 * The flags still come from the cmp against 16 - (MAX_STRIDE << 4)
	 * (only non-flag-setting instructions run in between), so eq means
	 * exactly 16 bytes remain.  In that case x7 stays 0.  Otherwise
	 * x7 = x7 - 16 (negative), which moves IN/OUT back so the 16-byte
	 * window ends at the end of the data.
	 */
555	sub		x8, x7, #16
556	csel		x7, x7, x8, eq
557	add		IN, IN, x7
558	add		OUT, OUT, x7
559	ld1		{v5.16b}, [IN]
	/* Current output contents, merged back into the lanes before the data. */
560	ld1		{v6.16b}, [OUT]
	/* The single counter block is in v{MAX_STRIDE-1}; encrypt it in v3. */
561ST5(	mov		v3.16b, v4.16b			)
562	encrypt_block	v3, ROUNDS_W, KEY, x8, w7
	/*
	 * v10 moves the keystream into the top lanes that hold data.  After
	 * sshr #7, v11 is all-ones in exactly those lanes, so bif keeps the
	 * existing OUT bytes everywhere else.
	 */
563	ld1		{v10.16b-v11.16b}, [x9]
564	tbl		v3.16b, {v3.16b}, v10.16b
565	sshr		v11.16b, v11.16b, #7
566	eor		v5.16b, v5.16b, v3.16b
567	bif		v5.16b, v6.16b, v11.16b
568	st1		{v5.16b}, [OUT]
569	b		.Lctrout\xctr
570
571	// Arguments
572	.unreq OUT
573	.unreq IN
574	.unreq KEY
575	.unreq ROUNDS_W
576	.unreq BYTES_W
577	.unreq IV
578	.unreq BYTE_CTR_W	// XCTR only
579	// Intermediate values
580	.unreq CTR_W		// XCTR only
581	.unreq CTR		// XCTR only
582	.unreq IV_PART
583	.unreq BLOCKS
584	.unreq BLOCKS_W
585.endm
586
587	/*
588	 * aes_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,
589	 *		   int bytes, u8 ctr[])
590	 *
591	 * The input and output buffers must always be at least 16 bytes even if
592	 * encrypting/decrypting less than 16 bytes.  Otherwise out of bounds
593	 * accesses will occur.  The data to be encrypted/decrypted is expected
594	 * to be at the end of this 16-byte temporary buffer rather than the
595	 * start.
596	 */
597
/*
 * CTR mode: ctr[] holds a big-endian 128-bit counter.  On return it holds
 * the next unused counter value (ctr_encrypt with xctr = 0).
 */
598AES_FUNC_START(aes_ctr_encrypt)
599	ctr_encrypt 0
600AES_FUNC_END(aes_ctr_encrypt)
601
602	/*
603	 * aes_xctr_encrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,
604	 *		   int bytes, u8 const iv[], int byte_ctr)
605	 *
606	 * The input and output buffers must always be at least 16 bytes even if
607	 * encrypting/decrypting less than 16 bytes.  Otherwise out of bounds
608	 * accesses will occur.  The data to be encrypted/decrypted is expected
609	 * to be at the end of this 16-byte temporary buffer rather than the
610	 * start.
611	 */
612
/*
 * XCTR mode (ctr_encrypt with xctr = 1).  byte_ctr (w6) is divided by 16 to
 * get the block counter; presumably it is the byte position of in[] within
 * the whole message.  The first block uses counter byte_ctr / 16 + 1,
 * XORed into the low 64 bits of iv[].  iv[] is not modified.
 */
613AES_FUNC_START(aes_xctr_encrypt)
614	ctr_encrypt 1
615AES_FUNC_END(aes_xctr_encrypt)
616
617
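618	/*
619	 * aes_xts_encrypt(u8 out[], u8 const in[], u8 const rk1[], int rounds,
620	 *		   int bytes, u8 const rk2[], u8 iv[], int first)
621	 * aes_xts_decrypt(u8 out[], u8 const in[], u8 const rk1[], int rounds,
622	 *		   int bytes, u8 const rk2[], u8 iv[], int first)
	 *
	 * If first is non-zero, the initial tweak is iv[] encrypted with rk2,
	 * using the same round count.  This is subject to xts_cts_skip_tw,
	 * which is defined by the including file and may skip that step.
	 * If first is zero, iv[] is taken as the last tweak used by a previous
	 * call and is advanced before use.  On return, iv[] holds the last
	 * tweak used.  A trailing partial block is handled with ciphertext
	 * stealing.
623	 */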
624
/*
 * next_tweak out, in, tmp - out = in * x in GF(2^128), the XTS tweak update.
 * The 128-bit tweak is held little-endian (d[0] = low 64 bits).  Each
 * 64-bit lane is doubled independently.  tmp then supplies the bit carried
 * from d[0] into d[1] and, for the bit shifted out of d[1], the reduction
 * constant 0x87 (x^7 + x^2 + x + 1) into d[0].  ext #8 swaps tmp's halves
 * to route each correction to the other lane.
 * Requires xtsmask = {d[0] = 1, d[1] = 0x87} (as built by xts_load_mask).
 * Only vector instructions are used, so NZCV is preserved.  Clobbers tmp.
 * out may equal in, but tmp must differ from both.
 */
625	.macro		next_tweak, out, in, tmp
626	sshr		\tmp\().2d,  \in\().2d,   #63
627	and		\tmp\().16b, \tmp\().16b, xtsmask.16b
628	add		\out\().2d,  \in\().2d,   \in\().2d
629	ext		\tmp\().16b, \tmp\().16b, \tmp\().16b, #8
630	eor		\out\().16b, \out\().16b, \tmp\().16b
631	.endm
632
/*
 * xts_load_mask tmp - build the next_tweak constant in xtsmask, a register
 * alias defined outside this file.  The two movi .2s results are
 * {1, 1, 0, 0} and {0x87, 0x87, 0, 0} as 32-bit lanes (upper halves
 * zeroed).  uzp1 takes their even lanes, giving {1, 0, 0x87, 0}, i.e.
 * xtsmask.d[0] = 1 and xtsmask.d[1] = 0x87.  Clobbers tmp.
 */
633	.macro		xts_load_mask, tmp
634	movi		xtsmask.2s, #0x1
635	movi		\tmp\().2s, #0x87
636	uzp1		xtsmask.4s, xtsmask.4s, \tmp\().4s
637	.endm
638
/*
 * aes_xts_encrypt - XTS encryption, with ciphertext stealing for a final
 * partial block.
 *
 * x0 = out, x1 = in (both post-incremented), x2 = rk1 (data key),
 * w3 = rounds (used for both keys), w4 = bytes remaining, x5 = rk2
 * (tweak key), x6 = iv[] (tweak in/out), w7 = first.
 * v4 = current tweak, v8 = scratch for next_tweak.
 * The bulk loop works in fixed 4-block (64-byte) strides, independent of
 * MAX_STRIDE.
 *
 * NOTE(review): bytes is expected to be >= 16.  With bytes == 0, the beq
 * to .Lxtsencout stores an unset v0.  With 0 < bytes < 16, .LxtsencctsNx
 * steals from a block that was never produced and rewinds x0 before out[].
 * Confirm that callers guarantee this.
 */
639AES_FUNC_START(aes_xts_encrypt)
640	frame_push	0
641
642	ld1		{v4.16b}, [x6]
643	xts_load_mask	v8
644	cbz		w7, .Lxtsencnotfirst
645
	/*
	 * First call: encrypt iv[] with rk2 to get the initial tweak, then
	 * switch to rk1.  xts_cts_skip_tw is defined by the including file
	 * and may branch past this step.
	 */
646	enc_prepare	w3, x5, x8
647	xts_cts_skip_tw	w7, .LxtsencNx
648	encrypt_block	v4, w3, x5, x8, w7		/* first tweak */
649	enc_switch_key	w3, x2, x8
650	b		.LxtsencNx
651
	/* Continuation: iv[] holds the last tweak used, so advance it first. */
652.Lxtsencnotfirst:
653	enc_prepare	w3, x2, x8
654.LxtsencloopNx:
655	next_tweak	v4, v4, v8
656.LxtsencNx:
657	subs		w4, w4, #64
658	bmi		.Lxtsenc1x
	/* Tweaks for this stride are v4..v7; each block is XORed before and after AES. */
659	ld1		{v0.16b-v3.16b}, [x1], #64	/* get 4 pt blocks */
660	next_tweak	v5, v4, v8
661	eor		v0.16b, v0.16b, v4.16b
662	next_tweak	v6, v5, v8
663	eor		v1.16b, v1.16b, v5.16b
664	eor		v2.16b, v2.16b, v6.16b
665	next_tweak	v7, v6, v8
666	eor		v3.16b, v3.16b, v7.16b
667	bl		aes_encrypt_block4x
668	eor		v3.16b, v3.16b, v7.16b
669	eor		v0.16b, v0.16b, v4.16b
670	eor		v1.16b, v1.16b, v5.16b
671	eor		v2.16b, v2.16b, v6.16b
672	st1		{v0.16b-v3.16b}, [x0], #64
673	mov		v4.16b, v7.16b
674	cbz		w4, .Lxtsencret
	/* Re-derive the mask (defined by the includer), presumably because the bl may clobber it. */
675	xts_reload_mask	v8
676	b		.LxtsencloopNx
677.Lxtsenc1x:
678	adds		w4, w4, #64
679	beq		.Lxtsencout
680	subs		w4, w4, #16
681	bmi		.LxtsencctsNx
682.Lxtsencloop:
683	ld1		{v0.16b}, [x1], #16
684.Lxtsencctsout:
685	eor		v0.16b, v0.16b, v4.16b
686	encrypt_block	v0, w3, x2, x8, w7
687	eor		v0.16b, v0.16b, v4.16b
688	cbz		w4, .Lxtsencout
689	subs		w4, w4, #16
	/* next_tweak leaves NZCV alone, so bmi still tests the subs above. */
690	next_tweak	v4, v4, v8
691	bmi		.Lxtsenccts
692	st1		{v0.16b}, [x0], #16
693	b		.Lxtsencloop
694.Lxtsencout:
695	st1		{v0.16b}, [x0]
696.Lxtsencret:
697	st1		{v4.16b}, [x6]
698	frame_pop
699	ret
700
	/*
	 * Fewer than 16 bytes follow the bulk stride.  The last bulk
	 * ciphertext block (v3, already stored) becomes the stealing source,
	 * and x0 is rewound so it gets rewritten.
	 */
701.LxtsencctsNx:
702	mov		v0.16b, v3.16b
703	sub		x0, x0, #16
704.Lxtsenccts:
	/*
	 * Ciphertext stealing.  v0 = the last full ciphertext block (not yet
	 * stored at x0).  v4 = the tweak for the final block.
	 * w4 = (bytes left) - 16, which is negative.
	 */
705	adr_l		x8, .Lcts_permute_table
706
707	add		x1, x1, w4, sxtw	/* rewind input pointer */
708	add		w4, w4, #16		/* # bytes in final block */
709	add		x9, x8, #32
710	add		x8, x8, x4
711	sub		x9, x9, x4
712	add		x4, x0, x4		/* output address of final block */
713
714	ld1		{v1.16b}, [x1]		/* load final block */
715	ld1		{v2.16b}, [x8]
716	ld1		{v3.16b}, [x9]
717
	/*
	 * v2 = the first d bytes of v0 moved to its top lanes: the truncated
	 * final ciphertext, stored at x0 + d.  tbx puts the d plaintext bytes
	 * into v0's low lanes and keeps the stolen bytes in the rest.
	 * w4 = 0 makes .Lxtsencctsout encrypt v0 and then exit via
	 * .Lxtsencout, which stores it at x0.
	 */
718	tbl		v2.16b, {v0.16b}, v2.16b
719	tbx		v0.16b, {v1.16b}, v3.16b
720	st1		{v2.16b}, [x4]			/* overlapping stores */
721	mov		w4, wzr
722	b		.Lxtsencctsout
723AES_FUNC_END(aes_xts_encrypt)
724
/*
 * aes_xts_decrypt - XTS decryption, with ciphertext stealing for a final
 * partial block.
 *
 * Registers as for aes_xts_encrypt: x0 = out, x1 = in, x2 = rk1,
 * w3 = rounds, w4 = bytes, x5 = rk2, x6 = iv[], w7 = first,
 * v4 = current tweak, v8 = scratch for next_tweak.
 *
 * If bytes is not a multiple of 16, the last full block is held back so
 * that it and the partial tail go through .Lxtsdeccts.  There, that block
 * is decrypted with the next tweak (v5), and the reassembled final block
 * with the current one (v4).  bytes is presumably >= 16 (caller's contract).
 */
725AES_FUNC_START(aes_xts_decrypt)
726	frame_push	0
727
728	/* subtract 16 bytes if we are doing CTS */
729	sub		w8, w4, #0x10
730	tst		w4, #0xf
731	csel		w4, w4, w8, eq
732
733	ld1		{v4.16b}, [x6]
734	xts_load_mask	v8
	/* xts_cts_skip_tw is defined by the including file and may skip tweak setup. */
735	xts_cts_skip_tw	w7, .Lxtsdecskiptw
736	cbz		w7, .Lxtsdecnotfirst
737
	/* First call: the tweak is always computed with the encryption key rk2. */
738	enc_prepare	w3, x5, x8
739	encrypt_block	v4, w3, x5, x8, w7		/* first tweak */
740.Lxtsdecskiptw:
741	dec_prepare	w3, x2, x8
742	b		.LxtsdecNx
743
	/* Continuation: iv[] holds the last tweak used, so advance it first. */
744.Lxtsdecnotfirst:
745	dec_prepare	w3, x2, x8
746.LxtsdecloopNx:
747	next_tweak	v4, v4, v8
748.LxtsdecNx:
749	subs		w4, w4, #64
750	bmi		.Lxtsdec1x
751	ld1		{v0.16b-v3.16b}, [x1], #64	/* get 4 ct blocks */
752	next_tweak	v5, v4, v8
753	eor		v0.16b, v0.16b, v4.16b
754	next_tweak	v6, v5, v8
755	eor		v1.16b, v1.16b, v5.16b
756	eor		v2.16b, v2.16b, v6.16b
757	next_tweak	v7, v6, v8
758	eor		v3.16b, v3.16b, v7.16b
759	bl		aes_decrypt_block4x
760	eor		v3.16b, v3.16b, v7.16b
761	eor		v0.16b, v0.16b, v4.16b
762	eor		v1.16b, v1.16b, v5.16b
763	eor		v2.16b, v2.16b, v6.16b
764	st1		{v0.16b-v3.16b}, [x0], #64
765	mov		v4.16b, v7.16b
766	cbz		w4, .Lxtsdecout
	/* Re-derive the mask (defined by the includer), presumably because the bl may clobber it. */
767	xts_reload_mask	v8
768	b		.LxtsdecloopNx
769.Lxtsdec1x:
770	adds		w4, w4, #64
771	beq		.Lxtsdecout
772	subs		w4, w4, #16
773.Lxtsdecloop:
	/*
	 * bmi tests the most recent subs.  ld1, next_tweak and b leave NZCV
	 * alone.  A negative count means only the partial tail follows this
	 * block.
	 */
774	ld1		{v0.16b}, [x1], #16
775	bmi		.Lxtsdeccts
776.Lxtsdecctsout:
777	eor		v0.16b, v0.16b, v4.16b
778	decrypt_block	v0, w3, x2, x8, w7
779	eor		v0.16b, v0.16b, v4.16b
780	st1		{v0.16b}, [x0], #16
781	cbz		w4, .Lxtsdecout
782	subs		w4, w4, #16
783	next_tweak	v4, v4, v8
784	b		.Lxtsdecloop
785.Lxtsdecout:
786	st1		{v4.16b}, [x6]
787	frame_pop
788	ret
789
790.Lxtsdeccts:
	/*
	 * Ciphertext stealing.  v0 = the last full ciphertext block.
	 * w4 = d - 16, where d is the length of the partial tail.
	 */
791	adr_l		x8, .Lcts_permute_table
792
793	add		x1, x1, w4, sxtw	/* rewind input pointer */
794	add		w4, w4, #16		/* # bytes in final block */
795	add		x9, x8, #32
796	add		x8, x8, x4
797	sub		x9, x9, x4
798	add		x4, x0, x4		/* output address of final block */
799
	/* The held-back block is decrypted with the tweak after v4. */
800	next_tweak	v5, v4, v8
801
802	ld1		{v1.16b}, [x1]		/* load final block */
803	ld1		{v2.16b}, [x8]
804	ld1		{v3.16b}, [x9]
805
806	eor		v0.16b, v0.16b, v5.16b
807	decrypt_block	v0, w3, x2, x8, w7
808	eor		v0.16b, v0.16b, v5.16b
809
	/*
	 * v2 = the first d plaintext bytes moved to the top lanes, stored at
	 * x0 + d as the final partial plaintext.  tbx rebuilds the full block
	 * from the tail ciphertext plus the stolen bytes.  With w4 = 0,
	 * .Lxtsdecctsout decrypts it with v4, stores it at x0 and exits.
	 */
810	tbl		v2.16b, {v0.16b}, v2.16b
811	tbx		v0.16b, {v1.16b}, v3.16b
812
813	st1		{v2.16b}, [x4]			/* overlapping stores */
814	mov		w4, wzr
815	b		.Lxtsdecctsout
816AES_FUNC_END(aes_xts_decrypt)
817
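818	/*
819	 * aes_mac_update(u8 const in[], u32 const rk[], int rounds,
820	 *		  int blocks, u8 dg[], int enc_before, int enc_after)
	 *
	 * XORs each block of in[] into the running digest dg[] and encrypts the
	 * result (CBC-MAC style).  If enc_before is non-zero, dg[] is encrypted
	 * once before any input is absorbed.  If enc_after is zero, the last
	 * block absorbed is XORed in but not encrypted.  Returns the number of
	 * blocks not yet absorbed.  This is non-zero only when the routine
	 * returns early via cond_yield.
821	 */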
/*
 * x0 = in (post-incremented), x1 = rk, w2 = rounds, w3 = blocks remaining,
 * x4 = dg[] (the running digest is kept in v0), w5 = enc_before,
 * w6 = enc_after.  After the entry check, x5 is reused as the "encrypt this
 * block?" flag.  Returns (w0) the count of blocks not yet absorbed.
 * No bl is made, so no frame is pushed.
 */
822AES_FUNC_START(aes_mac_update)
823	ld1		{v0.16b}, [x4]			/* get dg */
824	enc_prepare	w2, x1, x7
825	cbz		w5, .Lmacloop4x
826
827	encrypt_block	v0, w2, x1, x7, w8
828
829.Lmacloop4x:
830	subs		w3, w3, #4
831	bmi		.Lmac1x
832	ld1		{v1.16b-v4.16b}, [x0], #64	/* get next 4 pt blocks */
833	eor		v0.16b, v0.16b, v1.16b		/* ..and xor with dg */
834	encrypt_block	v0, w2, x1, x7, w8
835	eor		v0.16b, v0.16b, v2.16b
836	encrypt_block	v0, w2, x1, x7, w8
837	eor		v0.16b, v0.16b, v3.16b
838	encrypt_block	v0, w2, x1, x7, w8
839	eor		v0.16b, v0.16b, v4.16b
	/*
	 * x5 = (no blocks left) ? enc_after : all-ones.  The fourth block is
	 * therefore always encrypted unless it is the very last block and
	 * enc_after == 0.
	 */
840	cmp		w3, wzr
841	csinv		x5, x6, xzr, eq
842	cbz		w5, .Lmacout
843	encrypt_block	v0, w2, x1, x7, w8
844	st1		{v0.16b}, [x4]			/* return dg */
	/*
	 * cond_yield (defined outside this file) may branch to .Lmacout to
	 * return early.  dg[] has just been stored, and w3 holds the blocks
	 * still to do.
	 */
845	cond_yield	.Lmacout, x7, x8
846	b		.Lmacloop4x
847.Lmac1x:
	/* Undo the over-subtraction: 0..3 blocks remain. */
848	add		w3, w3, #4
849.Lmacloop:
850	cbz		w3, .Lmacout
851	ld1		{v1.16b}, [x0], #16		/* get next pt block */
852	eor		v0.16b, v0.16b, v1.16b		/* ..and xor with dg */
853
	/* Same rule as above: skip the encryption only for the last block when enc_after == 0. */
854	subs		w3, w3, #1
855	csinv		x5, x6, xzr, eq
856	cbz		w5, .Lmacout
857
	/* NOTE(review): .Lmacenc is not referenced anywhere in this file. */
858.Lmacenc:
859	encrypt_block	v0, w2, x1, x7, w8
860	b		.Lmacloop
861
862.Lmacout:
863	st1		{v0.16b}, [x4]			/* return dg */
864	mov		w0, w3
865	ret
866AES_FUNC_END(aes_mac_update)
867