1// SPDX-License-Identifier: GPL-2.0-only
2/*
3 * BPF JIT compiler for ARM64
4 *
5 * Copyright (C) 2014-2016 Zi Shen Lim <zlim.lnx@gmail.com>
6 */
7
8#define pr_fmt(fmt) "bpf_jit: " fmt
9
10#include <linux/bitfield.h>
11#include <linux/bpf.h>
12#include <linux/filter.h>
13#include <linux/memory.h>
14#include <linux/printk.h>
15#include <linux/slab.h>
16
17#include <asm/asm-extable.h>
18#include <asm/byteorder.h>
19#include <asm/cacheflush.h>
20#include <asm/debug-monitors.h>
21#include <asm/insn.h>
22#include <asm/patching.h>
23#include <asm/set_memory.h>
24
25#include "bpf_jit.h"
26
27#define TMP_REG_1 (MAX_BPF_JIT_REG + 0)
28#define TMP_REG_2 (MAX_BPF_JIT_REG + 1)
29#define TCALL_CNT (MAX_BPF_JIT_REG + 2)
30#define TMP_REG_3 (MAX_BPF_JIT_REG + 3)
31#define FP_BOTTOM (MAX_BPF_JIT_REG + 4)
32
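/* Bail out of the JIT (-EINVAL) if the branch offset @imm is out of range. */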
33#define check_imm(bits, imm) do {				\
34	if ((((imm) > 0) && ((imm) >> (bits))) ||		\
35	    (((imm) < 0) && (~(imm) >> (bits)))) {		\
36		pr_info("[%2d] imm=%d(0x%x) out of range\n",	\
37			i, imm, imm);				\
38		return -EINVAL;					\
39	}							\
40} while (0)
41#define check_imm19(imm) check_imm(19, imm)
42#define check_imm26(imm) check_imm(26, imm)
43
44/* Map BPF registers to A64 registers */
45static const int bpf2a64[] = {
46	/* return value from in-kernel function, and exit value from eBPF */
47	[BPF_REG_0] = A64_R(7),
48	/* arguments from eBPF program to in-kernel function */
49	[BPF_REG_1] = A64_R(0),
50	[BPF_REG_2] = A64_R(1),
51	[BPF_REG_3] = A64_R(2),
52	[BPF_REG_4] = A64_R(3),
53	[BPF_REG_5] = A64_R(4),
54	/* callee saved registers that in-kernel function will preserve */
55	[BPF_REG_6] = A64_R(19),
56	[BPF_REG_7] = A64_R(20),
57	[BPF_REG_8] = A64_R(21),
58	[BPF_REG_9] = A64_R(22),
59	/* read-only frame pointer to access stack */
60	[BPF_REG_FP] = A64_R(25),
61	/* temporary registers for BPF JIT */
62	[TMP_REG_1] = A64_R(10),
63	[TMP_REG_2] = A64_R(11),
64	[TMP_REG_3] = A64_R(12),
65	/* tail_call_cnt */
66	[TCALL_CNT] = A64_R(26),
67	/* temporary register for blinding constants */
68	[BPF_REG_AX] = A64_R(9),
69	[FP_BOTTOM] = A64_R(27),
70};
71
72struct jit_ctx {
73	const struct bpf_prog *prog;
74	int idx;
75	int epilogue_offset;
76	int *offset;
77	int exentry_idx;
78	__le32 *image;
79	__le32 *ro_image;
80	u32 stack_size;
81	int fpb_offset;
82};
83
84struct bpf_plt {
85	u32 insn_ldr; /* load target */
86	u32 insn_br;  /* branch to target */
87	u64 target;   /* target value */
88};
89
90#define PLT_TARGET_SIZE   sizeof_field(struct bpf_plt, target)
91#define PLT_TARGET_OFFSET offsetof(struct bpf_plt, target)
92
93static inline void emit(const u32 insn, struct jit_ctx *ctx)
94{
95	if (ctx->image != NULL)
96		ctx->image[ctx->idx] = cpu_to_le32(insn);
97
98	ctx->idx++;
99}
100
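/*
 * Move a 32-bit immediate into @reg, using at most two instructions:
 * MOVN/MOVZ for one 16-bit half plus an optional MOVK for the other.
 */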
101static inline void emit_a64_mov_i(const int is64, const int reg,
102				  const s32 val, struct jit_ctx *ctx)
103{
104	u16 hi = val >> 16;
105	u16 lo = val & 0xffff;
106
107	if (hi & 0x8000) {
108		if (hi == 0xffff) {
109			emit(A64_MOVN(is64, reg, (u16)~lo, 0), ctx);
110		} else {
111			emit(A64_MOVN(is64, reg, (u16)~hi, 16), ctx);
112			if (lo != 0xffff)
113				emit(A64_MOVK(is64, reg, lo, 0), ctx);
114		}
115	} else {
116		emit(A64_MOVZ(is64, reg, lo, 0), ctx);
117		if (hi)
118			emit(A64_MOVK(is64, reg, hi, 16), ctx);
119	}
120}
121
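/*
 * Count the 16-bit chunks of @val that a MOVZ-based (inverse == false) or
 * MOVN-based (inverse == true) sequence would have to set explicitly.
 */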
122static int i64_i16_blocks(const u64 val, bool inverse)
123{
124	return (((val >>  0) & 0xffff) != (inverse ? 0xffff : 0x0000)) +
125	       (((val >> 16) & 0xffff) != (inverse ? 0xffff : 0x0000)) +
126	       (((val >> 32) & 0xffff) != (inverse ? 0xffff : 0x0000)) +
127	       (((val >> 48) & 0xffff) != (inverse ? 0xffff : 0x0000));
128}
129
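/*
 * Move a 64-bit immediate into @reg: pick whichever of a MOVZ- or MOVN-based
 * sequence needs fewer instructions, then patch the remaining 16-bit chunks
 * with MOVK.
 */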
130static inline void emit_a64_mov_i64(const int reg, const u64 val,
131				    struct jit_ctx *ctx)
132{
133	u64 nrm_tmp = val, rev_tmp = ~val;
134	bool inverse;
135	int shift;
136
137	if (!(nrm_tmp >> 32))
138		return emit_a64_mov_i(0, reg, (u32)val, ctx);
139
140	inverse = i64_i16_blocks(nrm_tmp, true) < i64_i16_blocks(nrm_tmp, false);
141	shift = max(round_down((inverse ? (fls64(rev_tmp) - 1) :
142					  (fls64(nrm_tmp) - 1)), 16), 0);
143	if (inverse)
144		emit(A64_MOVN(1, reg, (rev_tmp >> shift) & 0xffff, shift), ctx);
145	else
146		emit(A64_MOVZ(1, reg, (nrm_tmp >> shift) & 0xffff, shift), ctx);
147	shift -= 16;
148	while (shift >= 0) {
149		if (((nrm_tmp >> shift) & 0xffff) != (inverse ? 0xffff : 0x0000))
150			emit(A64_MOVK(1, reg, (nrm_tmp >> shift) & 0xffff, shift), ctx);
151		shift -= 16;
152	}
153}
154
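/* Emit a BTI landing pad only when the kernel is built with BTI support. */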
155static inline void emit_bti(u32 insn, struct jit_ctx *ctx)
156{
157	if (IS_ENABLED(CONFIG_ARM64_BTI_KERNEL))
158		emit(insn, ctx);
159}
160
161/*
162 * Kernel addresses in the vmalloc space use at most 48 bits, and the
163 * remaining upper bits are guaranteed to be all ones. So we can compose
164 * the address with a fixed-length movn/movk/movk sequence.
165 */
166static inline void emit_addr_mov_i64(const int reg, const u64 val,
167				     struct jit_ctx *ctx)
168{
169	u64 tmp = val;
170	int shift = 0;
171
172	emit(A64_MOVN(1, reg, ~tmp & 0xffff, shift), ctx);
173	while (shift < 32) {
174		tmp >>= 16;
175		shift += 16;
176		emit(A64_MOVK(1, reg, tmp & 0xffff, shift), ctx);
177	}
178}
179
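/* Emit an indirect call to @target via the first temporary register. */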
180static inline void emit_call(u64 target, struct jit_ctx *ctx)
181{
182	u8 tmp = bpf2a64[TMP_REG_1];
183
184	emit_addr_mov_i64(tmp, target, ctx);
185	emit(A64_BLR(tmp), ctx);
186}
187
188static inline int bpf2a64_offset(int bpf_insn, int off,
189				 const struct jit_ctx *ctx)
190{
191	/* BPF JMP offset is relative to the next instruction */
192	bpf_insn++;
193	/*
194	 * arm64 branch instructions, on the other hand, encode the offset
195	 * from the branch itself, so we must subtract 1 from the
196	 * instruction offset.
197	 */
198	return ctx->offset[bpf_insn + off] - (ctx->offset[bpf_insn] - 1);
199}
200
201static void jit_fill_hole(void *area, unsigned int size)
202{
203	__le32 *ptr;
204	/* We are guaranteed to have aligned memory. */
205	for (ptr = area; size >= sizeof(u32); size -= sizeof(u32))
206		*ptr++ = cpu_to_le32(AARCH64_BREAK_FAULT);
207}
208
209int bpf_arch_text_invalidate(void *dst, size_t len)
210{
211	if (!aarch64_insn_set(dst, AARCH64_BREAK_FAULT, len))
212		return -EINVAL;
213
214	return 0;
215}
216
217static inline int epilogue_offset(const struct jit_ctx *ctx)
218{
219	int to = ctx->epilogue_offset;
220	int from = ctx->idx;
221
222	return to - from;
223}
224
225static bool is_addsub_imm(u32 imm)
226{
227	/* Either imm12 or shifted imm12. */
228	return !(imm & ~0xfff) || !(imm & ~0xfff000);
229}
230
231/*
232 * There are 3 types of AArch64 LDR/STR (immediate) instruction:
233 * Post-index, Pre-index, Unsigned offset.
234 *
235 * For BPF ldr/str, the "unsigned offset" type is sufficient.
236 *
237 * "Unsigned offset" type LDR(immediate) format:
238 *
239 *    3                   2                   1                   0
240 *  1 0 9 8 7 6 5 4 3 2 1 0 9 8 7 6 5 4 3 2 1 0 9 8 7 6 5 4 3 2 1 0
241 * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
242 * |x x|1 1 1 0 0 1 0 1|         imm12         |    Rn   |    Rt   |
243 * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
244 * scale
245 *
246 * "Unsigned offset" type STR(immediate) format:
247 *    3                   2                   1                   0
248 *  1 0 9 8 7 6 5 4 3 2 1 0 9 8 7 6 5 4 3 2 1 0 9 8 7 6 5 4 3 2 1 0
249 * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
250 * |x x|1 1 1 0 0 1 0 0|         imm12         |    Rn   |    Rt   |
251 * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
252 * scale
253 *
254 * The offset is calculated from imm12 and scale in the following way:
255 *
256 * offset = (u64)imm12 << scale
257 */
258static bool is_lsi_offset(int offset, int scale)
259{
260	if (offset < 0)
261		return false;
262
263	if (offset > (0xFFF << scale))
264		return false;
265
266	if (offset & ((1 << scale) - 1))
267		return false;
268
269	return true;
270}
271
272/* generated prologue:
273 *      bti c // if CONFIG_ARM64_BTI_KERNEL
274 *      mov x9, lr
275 *      nop  // POKE_OFFSET
276 *      paciasp // if CONFIG_ARM64_PTR_AUTH_KERNEL
277 *      stp x29, lr, [sp, #-16]!
278 *      mov x29, sp
279 *      stp x19, x20, [sp, #-16]!
280 *      stp x21, x22, [sp, #-16]!
281 *      stp x25, x26, [sp, #-16]!
282 *      stp x27, x28, [sp, #-16]!
283 *      mov x25, sp
284 *      mov tcc, #0
285 *      // PROLOGUE_OFFSET
286 */
287
288#define BTI_INSNS (IS_ENABLED(CONFIG_ARM64_BTI_KERNEL) ? 1 : 0)
289#define PAC_INSNS (IS_ENABLED(CONFIG_ARM64_PTR_AUTH_KERNEL) ? 1 : 0)
290
291/* Offset of nop instruction in bpf prog entry to be poked */
292#define POKE_OFFSET (BTI_INSNS + 1)
293
294/* Tail call offset to jump into */
295#define PROLOGUE_OFFSET (BTI_INSNS + 2 + PAC_INSNS + 8)
296
297static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf,
298			  bool is_exception_cb)
299{
300	const struct bpf_prog *prog = ctx->prog;
301	const bool is_main_prog = !bpf_is_subprog(prog);
302	const u8 r6 = bpf2a64[BPF_REG_6];
303	const u8 r7 = bpf2a64[BPF_REG_7];
304	const u8 r8 = bpf2a64[BPF_REG_8];
305	const u8 r9 = bpf2a64[BPF_REG_9];
306	const u8 fp = bpf2a64[BPF_REG_FP];
307	const u8 tcc = bpf2a64[TCALL_CNT];
308	const u8 fpb = bpf2a64[FP_BOTTOM];
309	const int idx0 = ctx->idx;
310	int cur_offset;
311
312	/*
313	 * BPF prog stack layout
314	 *
315	 *                         high
316	 * original A64_SP =>   0:+-----+ BPF prologue
317	 *                        |FP/LR|
318	 * current A64_FP =>  -16:+-----+
319	 *                        | ... | callee saved registers
320	 * BPF fp register => -64:+-----+ <= (BPF_FP)
321	 *                        |     |
322	 *                        | ... | BPF prog stack
323	 *                        |     |
324	 *                        +-----+ <= (BPF_FP - prog->aux->stack_depth)
325	 *                        |RSVD | padding
326	 * current A64_SP =>      +-----+ <= (BPF_FP - ctx->stack_size)
327	 *                        |     |
328	 *                        | ... | Function call stack
329	 *                        |     |
330	 *                        +-----+
331	 *                          low
332	 *
333	 */
334
335	/* A bpf function may be invoked by 3 instruction types:
336	 * 1. bl, attached to a bpf prog via freplace, using a short jump
337	 * 2. br, attached to a bpf prog via freplace, using a long jump
338	 * 3. blr, working as a function pointer, used by emit_call.
339	 * So BTI_JC should be used here to support both br and blr.
340	 */
341	emit_bti(A64_BTI_JC, ctx);
342
343	emit(A64_MOV(1, A64_R(9), A64_LR), ctx);
344	emit(A64_NOP, ctx);
345
346	if (!is_exception_cb) {
347		/* Sign lr */
348		if (IS_ENABLED(CONFIG_ARM64_PTR_AUTH_KERNEL))
349			emit(A64_PACIASP, ctx);
350		/* Save FP and LR registers to comply with the ARM64 AAPCS */
351		emit(A64_PUSH(A64_FP, A64_LR, A64_SP), ctx);
352		emit(A64_MOV(1, A64_FP, A64_SP), ctx);
353
354		/* Save callee-saved registers */
355		emit(A64_PUSH(r6, r7, A64_SP), ctx);
356		emit(A64_PUSH(r8, r9, A64_SP), ctx);
357		emit(A64_PUSH(fp, tcc, A64_SP), ctx);
358		emit(A64_PUSH(fpb, A64_R(28), A64_SP), ctx);
359	} else {
360		/*
361		 * Exception callback receives FP of Main Program as third
362		 * parameter
363		 */
364		emit(A64_MOV(1, A64_FP, A64_R(2)), ctx);
365		/*
366		 * Main Program already pushed the frame record and the
367		 * callee-saved registers. The exception callback will not push
368		 * anything and re-use the main program's stack.
369		 *
370		 * 10 registers are on the stack
371		 */
372		emit(A64_SUB_I(1, A64_SP, A64_FP, 80), ctx);
373	}
374
375	/* Set up BPF prog stack base register */
376	emit(A64_MOV(1, fp, A64_SP), ctx);
377
378	if (!ebpf_from_cbpf && is_main_prog) {
379		/* Initialize tail_call_cnt */
380		emit(A64_MOVZ(1, tcc, 0, 0), ctx);
381
382		cur_offset = ctx->idx - idx0;
383		if (cur_offset != PROLOGUE_OFFSET) {
384			pr_err_once("PROLOGUE_OFFSET = %d, expected %d!\n",
385				    cur_offset, PROLOGUE_OFFSET);
386			return -1;
387		}
388
389		/* BTI landing pad for the tail call, done with a BR */
390		emit_bti(A64_BTI_J, ctx);
391	}
392
393	/*
394	 * A program acting as an exception boundary must save all ARM64
395	 * callee-saved registers, as the exception callback needs to restore
396	 * them all in its epilogue.
397	 */
398	if (prog->aux->exception_boundary) {
399		/*
400		 * As we are pushing two more registers, BPF_FP should be moved
401		 * 16 bytes
402		 */
403		emit(A64_SUB_I(1, fp, fp, 16), ctx);
404		emit(A64_PUSH(A64_R(23), A64_R(24), A64_SP), ctx);
405	}
406
407	emit(A64_SUB_I(1, fpb, fp, ctx->fpb_offset), ctx);
408
409	/* Stack size must be a multiple of 16 bytes */
410	ctx->stack_size = round_up(prog->aux->stack_depth, 16);
411
412	/* Set up function call stack */
413	emit(A64_SUB_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);
414	return 0;
415}
416
417static int out_offset = -1; /* initialized on the first pass of build_body() */
418static int emit_bpf_tail_call(struct jit_ctx *ctx)
419{
420	/* bpf_tail_call(void *prog_ctx, struct bpf_array *array, u64 index) */
421	const u8 r2 = bpf2a64[BPF_REG_2];
422	const u8 r3 = bpf2a64[BPF_REG_3];
423
424	const u8 tmp = bpf2a64[TMP_REG_1];
425	const u8 prg = bpf2a64[TMP_REG_2];
426	const u8 tcc = bpf2a64[TCALL_CNT];
427	const int idx0 = ctx->idx;
428#define cur_offset (ctx->idx - idx0)
429#define jmp_offset (out_offset - (cur_offset))
430	size_t off;
431
432	/* if (index >= array->map.max_entries)
433	 *     goto out;
434	 */
435	off = offsetof(struct bpf_array, map.max_entries);
436	emit_a64_mov_i64(tmp, off, ctx);
437	emit(A64_LDR32(tmp, r2, tmp), ctx);
438	emit(A64_MOV(0, r3, r3), ctx);
439	emit(A64_CMP(0, r3, tmp), ctx);
440	emit(A64_B_(A64_COND_CS, jmp_offset), ctx);
441
442	/*
443	 * if (tail_call_cnt >= MAX_TAIL_CALL_CNT)
444	 *     goto out;
445	 * tail_call_cnt++;
446	 */
447	emit_a64_mov_i64(tmp, MAX_TAIL_CALL_CNT, ctx);
448	emit(A64_CMP(1, tcc, tmp), ctx);
449	emit(A64_B_(A64_COND_CS, jmp_offset), ctx);
450	emit(A64_ADD_I(1, tcc, tcc, 1), ctx);
451
452	/* prog = array->ptrs[index];
453	 * if (prog == NULL)
454	 *     goto out;
455	 */
456	off = offsetof(struct bpf_array, ptrs);
457	emit_a64_mov_i64(tmp, off, ctx);
458	emit(A64_ADD(1, tmp, r2, tmp), ctx);
459	emit(A64_LSL(1, prg, r3, 3), ctx);
460	emit(A64_LDR64(prg, tmp, prg), ctx);
461	emit(A64_CBZ(1, prg, jmp_offset), ctx);
462
463	/* goto *(prog->bpf_func + prologue_offset); */
464	off = offsetof(struct bpf_prog, bpf_func);
465	emit_a64_mov_i64(tmp, off, ctx);
466	emit(A64_LDR64(tmp, prg, tmp), ctx);
467	emit(A64_ADD_I(1, tmp, tmp, sizeof(u32) * PROLOGUE_OFFSET), ctx);
468	emit(A64_ADD_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);
469	emit(A64_BR(tmp), ctx);
470
471	/* out: */
472	if (out_offset == -1)
473		out_offset = cur_offset;
474	if (cur_offset != out_offset) {
475		pr_err_once("tail_call out_offset = %d, expected %d!\n",
476			    cur_offset, out_offset);
477		return -1;
478	}
479	return 0;
480#undef cur_offset
481#undef jmp_offset
482}
483
484#ifdef CONFIG_ARM64_LSE_ATOMICS
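/*
 * Emit a BPF atomic operation using LSE instructions (ARMv8.1 atomics),
 * avoiding an exclusive load/store retry loop.
 */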
485static int emit_lse_atomic(const struct bpf_insn *insn, struct jit_ctx *ctx)
486{
487	const u8 code = insn->code;
488	const u8 dst = bpf2a64[insn->dst_reg];
489	const u8 src = bpf2a64[insn->src_reg];
490	const u8 tmp = bpf2a64[TMP_REG_1];
491	const u8 tmp2 = bpf2a64[TMP_REG_2];
492	const bool isdw = BPF_SIZE(code) == BPF_DW;
493	const s16 off = insn->off;
494	u8 reg;
495
496	if (!off) {
497		reg = dst;
498	} else {
499		emit_a64_mov_i(1, tmp, off, ctx);
500		emit(A64_ADD(1, tmp, tmp, dst), ctx);
501		reg = tmp;
502	}
503
504	switch (insn->imm) {
505	/* lock *(u32/u64 *)(dst_reg + off) <op>= src_reg */
506	case BPF_ADD:
507		emit(A64_STADD(isdw, reg, src), ctx);
508		break;
509	case BPF_AND:
510		emit(A64_MVN(isdw, tmp2, src), ctx);
511		emit(A64_STCLR(isdw, reg, tmp2), ctx);
512		break;
513	case BPF_OR:
514		emit(A64_STSET(isdw, reg, src), ctx);
515		break;
516	case BPF_XOR:
517		emit(A64_STEOR(isdw, reg, src), ctx);
518		break;
519	/* src_reg = atomic_fetch_<op>(dst_reg + off, src_reg) */
520	case BPF_ADD | BPF_FETCH:
521		emit(A64_LDADDAL(isdw, src, reg, src), ctx);
522		break;
523	case BPF_AND | BPF_FETCH:
524		emit(A64_MVN(isdw, tmp2, src), ctx);
525		emit(A64_LDCLRAL(isdw, src, reg, tmp2), ctx);
526		break;
527	case BPF_OR | BPF_FETCH:
528		emit(A64_LDSETAL(isdw, src, reg, src), ctx);
529		break;
530	case BPF_XOR | BPF_FETCH:
531		emit(A64_LDEORAL(isdw, src, reg, src), ctx);
532		break;
533	/* src_reg = atomic_xchg(dst_reg + off, src_reg); */
534	case BPF_XCHG:
535		emit(A64_SWPAL(isdw, src, reg, src), ctx);
536		break;
537	/* r0 = atomic_cmpxchg(dst_reg + off, r0, src_reg); */
538	case BPF_CMPXCHG:
539		emit(A64_CASAL(isdw, src, reg, bpf2a64[BPF_REG_0]), ctx);
540		break;
541	default:
542		pr_err_once("unknown atomic op code %02x\n", insn->imm);
543		return -EINVAL;
544	}
545
546	return 0;
547}
548#else
549static inline int emit_lse_atomic(const struct bpf_insn *insn, struct jit_ctx *ctx)
550{
551	return -EINVAL;
552}
553#endif
554
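/*
 * Emit a BPF atomic operation as a load-exclusive/store-exclusive retry
 * loop, for CPUs without LSE atomics.
 */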
555static int emit_ll_sc_atomic(const struct bpf_insn *insn, struct jit_ctx *ctx)
556{
557	const u8 code = insn->code;
558	const u8 dst = bpf2a64[insn->dst_reg];
559	const u8 src = bpf2a64[insn->src_reg];
560	const u8 tmp = bpf2a64[TMP_REG_1];
561	const u8 tmp2 = bpf2a64[TMP_REG_2];
562	const u8 tmp3 = bpf2a64[TMP_REG_3];
563	const int i = insn - ctx->prog->insnsi;
564	const s32 imm = insn->imm;
565	const s16 off = insn->off;
566	const bool isdw = BPF_SIZE(code) == BPF_DW;
567	u8 reg;
568	s32 jmp_offset;
569
570	if (!off) {
571		reg = dst;
572	} else {
573		emit_a64_mov_i(1, tmp, off, ctx);
574		emit(A64_ADD(1, tmp, tmp, dst), ctx);
575		reg = tmp;
576	}
577
578	if (imm == BPF_ADD || imm == BPF_AND ||
579	    imm == BPF_OR || imm == BPF_XOR) {
580		/* lock *(u32/u64 *)(dst_reg + off) <op>= src_reg */
581		emit(A64_LDXR(isdw, tmp2, reg), ctx);
582		if (imm == BPF_ADD)
583			emit(A64_ADD(isdw, tmp2, tmp2, src), ctx);
584		else if (imm == BPF_AND)
585			emit(A64_AND(isdw, tmp2, tmp2, src), ctx);
586		else if (imm == BPF_OR)
587			emit(A64_ORR(isdw, tmp2, tmp2, src), ctx);
588		else
589			emit(A64_EOR(isdw, tmp2, tmp2, src), ctx);
590		emit(A64_STXR(isdw, tmp2, reg, tmp3), ctx);
591		jmp_offset = -3;
592		check_imm19(jmp_offset);
593		emit(A64_CBNZ(0, tmp3, jmp_offset), ctx);
594	} else if (imm == (BPF_ADD | BPF_FETCH) ||
595		   imm == (BPF_AND | BPF_FETCH) ||
596		   imm == (BPF_OR | BPF_FETCH) ||
597		   imm == (BPF_XOR | BPF_FETCH)) {
598		/* src_reg = atomic_fetch_<op>(dst_reg + off, src_reg) */
599		const u8 ax = bpf2a64[BPF_REG_AX];
600
601		emit(A64_MOV(isdw, ax, src), ctx);
602		emit(A64_LDXR(isdw, src, reg), ctx);
603		if (imm == (BPF_ADD | BPF_FETCH))
604			emit(A64_ADD(isdw, tmp2, src, ax), ctx);
605		else if (imm == (BPF_AND | BPF_FETCH))
606			emit(A64_AND(isdw, tmp2, src, ax), ctx);
607		else if (imm == (BPF_OR | BPF_FETCH))
608			emit(A64_ORR(isdw, tmp2, src, ax), ctx);
609		else
610			emit(A64_EOR(isdw, tmp2, src, ax), ctx);
611		emit(A64_STLXR(isdw, tmp2, reg, tmp3), ctx);
612		jmp_offset = -3;
613		check_imm19(jmp_offset);
614		emit(A64_CBNZ(0, tmp3, jmp_offset), ctx);
615		emit(A64_DMB_ISH, ctx);
616	} else if (imm == BPF_XCHG) {
617		/* src_reg = atomic_xchg(dst_reg + off, src_reg); */
618		emit(A64_MOV(isdw, tmp2, src), ctx);
619		emit(A64_LDXR(isdw, src, reg), ctx);
620		emit(A64_STLXR(isdw, tmp2, reg, tmp3), ctx);
621		jmp_offset = -2;
622		check_imm19(jmp_offset);
623		emit(A64_CBNZ(0, tmp3, jmp_offset), ctx);
624		emit(A64_DMB_ISH, ctx);
625	} else if (imm == BPF_CMPXCHG) {
626		/* r0 = atomic_cmpxchg(dst_reg + off, r0, src_reg); */
627		const u8 r0 = bpf2a64[BPF_REG_0];
628
629		emit(A64_MOV(isdw, tmp2, r0), ctx);
630		emit(A64_LDXR(isdw, r0, reg), ctx);
631		emit(A64_EOR(isdw, tmp3, r0, tmp2), ctx);
632		jmp_offset = 4;
633		check_imm19(jmp_offset);
634		emit(A64_CBNZ(isdw, tmp3, jmp_offset), ctx);
635		emit(A64_STLXR(isdw, src, reg, tmp3), ctx);
636		jmp_offset = -4;
637		check_imm19(jmp_offset);
638		emit(A64_CBNZ(0, tmp3, jmp_offset), ctx);
639		emit(A64_DMB_ISH, ctx);
640	} else {
641		pr_err_once("unknown atomic op code %02x\n", imm);
642		return -EINVAL;
643	}
644
645	return 0;
646}
647
648void dummy_tramp(void);
649
650asm (
651"	.pushsection .text, \"ax\", @progbits\n"
652"	.global dummy_tramp\n"
653"	.type dummy_tramp, %function\n"
654"dummy_tramp:"
655#if IS_ENABLED(CONFIG_ARM64_BTI_KERNEL)
656"	bti j\n" /* dummy_tramp is called via "br x10" */
657#endif
658"	mov x10, x30\n"
659"	mov x30, x9\n"
660"	ret x10\n"
661"	.size dummy_tramp, .-dummy_tramp\n"
662"	.popsection\n"
663);
664
665/* build a plt initialized like this:
666 *
667 * plt:
668 *      ldr tmp, target
669 *      br tmp
670 * target:
671 *      .quad dummy_tramp
672 *
673 * when a long jump trampoline is attached, target is filled with the
674 * trampoline address, and when the trampoline is removed, target is
675 * restored to dummy_tramp address.
676 */
677static void build_plt(struct jit_ctx *ctx)
678{
679	const u8 tmp = bpf2a64[TMP_REG_1];
680	struct bpf_plt *plt = NULL;
681
682	/* make sure target is 64-bit aligned */
683	if ((ctx->idx + PLT_TARGET_OFFSET / AARCH64_INSN_SIZE) % 2)
684		emit(A64_NOP, ctx);
685
686	plt = (struct bpf_plt *)(ctx->image + ctx->idx);
687	/* plt is called via bl, no BTI needed here */
688	emit(A64_LDR64LIT(tmp, 2 * AARCH64_INSN_SIZE), ctx);
689	emit(A64_BR(tmp), ctx);
690
691	if (ctx->image)
692		plt->target = (u64)&dummy_tramp;
693}
694
695static void build_epilogue(struct jit_ctx *ctx, bool is_exception_cb)
696{
697	const u8 r0 = bpf2a64[BPF_REG_0];
698	const u8 r6 = bpf2a64[BPF_REG_6];
699	const u8 r7 = bpf2a64[BPF_REG_7];
700	const u8 r8 = bpf2a64[BPF_REG_8];
701	const u8 r9 = bpf2a64[BPF_REG_9];
702	const u8 fp = bpf2a64[BPF_REG_FP];
703	const u8 fpb = bpf2a64[FP_BOTTOM];
704
705	/* We're done with BPF stack */
706	emit(A64_ADD_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);
707
708	/*
709	 * A program acting as an exception boundary pushes R23 and R24 in
710	 * addition to the BPF callee-saved registers. The exception callback
711	 * reuses the boundary program's stack frame, so restore these extra
712	 * registers in both cases.
713	 */
714	if (ctx->prog->aux->exception_boundary || is_exception_cb)
715		emit(A64_POP(A64_R(23), A64_R(24), A64_SP), ctx);
716
717	/* Restore x27 and x28 */
718	emit(A64_POP(fpb, A64_R(28), A64_SP), ctx);
719	/* Restore fp (x25) and x26 */
720	emit(A64_POP(fp, A64_R(26), A64_SP), ctx);
721
722	/* Restore callee-saved registers */
723	emit(A64_POP(r8, r9, A64_SP), ctx);
724	emit(A64_POP(r6, r7, A64_SP), ctx);
725
726	/* Restore FP/LR registers */
727	emit(A64_POP(A64_FP, A64_LR, A64_SP), ctx);
728
729	/* Set return value */
730	emit(A64_MOV(1, A64_R(0), r0), ctx);
731
732	/* Authenticate lr */
733	if (IS_ENABLED(CONFIG_ARM64_PTR_AUTH_KERNEL))
734		emit(A64_AUTIASP, ctx);
735
736	emit(A64_RET(A64_LR), ctx);
737}
738
739#define BPF_FIXUP_OFFSET_MASK	GENMASK(26, 0)
740#define BPF_FIXUP_REG_MASK	GENMASK(31, 27)
741
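/*
 * Exception fixup handler: zero the destination register and resume
 * execution at the instruction following the faulting load.
 */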
742bool ex_handler_bpf(const struct exception_table_entry *ex,
743		    struct pt_regs *regs)
744{
745	off_t offset = FIELD_GET(BPF_FIXUP_OFFSET_MASK, ex->fixup);
746	int dst_reg = FIELD_GET(BPF_FIXUP_REG_MASK, ex->fixup);
747
748	regs->regs[dst_reg] = 0;
749	regs->pc = (unsigned long)&ex->fixup - offset;
750	return true;
751}
752
753/* For accesses to BTF pointers, add an entry to the exception table */
754static int add_exception_handler(const struct bpf_insn *insn,
755				 struct jit_ctx *ctx,
756				 int dst_reg)
757{
758	off_t ins_offset;
759	off_t fixup_offset;
760	unsigned long pc;
761	struct exception_table_entry *ex;
762
763	if (!ctx->image)
764		/* First pass */
765		return 0;
766
767	if (BPF_MODE(insn->code) != BPF_PROBE_MEM &&
768		BPF_MODE(insn->code) != BPF_PROBE_MEMSX)
769		return 0;
770
771	if (!ctx->prog->aux->extable ||
772	    WARN_ON_ONCE(ctx->exentry_idx >= ctx->prog->aux->num_exentries))
773		return -EINVAL;
774
775	ex = &ctx->prog->aux->extable[ctx->exentry_idx];
776	pc = (unsigned long)&ctx->ro_image[ctx->idx - 1];
777
778	/*
779	 * This is the relative offset of the instruction that may fault from
780	 * the exception table itself. This will be written to the exception
781	 * table and if this instruction faults, the destination register will
782	 * be set to '0' and the execution will jump to the next instruction.
783	 */
784	ins_offset = pc - (long)&ex->insn;
785	if (WARN_ON_ONCE(ins_offset >= 0 || ins_offset < INT_MIN))
786		return -ERANGE;
787
788	/*
789	 * Since the extable follows the program, the fixup offset is always
790	 * negative and limited to BPF_JIT_REGION_SIZE. Store a positive value
791	 * to keep things simple, and put the destination register in the upper
792	 * bits. We don't need to worry about buildtime or runtime sort
793	 * modifying the upper bits because the table is already sorted, and
794	 * isn't part of the main exception table.
795	 *
796	 * The fixup_offset is set to the next instruction from the instruction
797	 * that may fault. The execution will jump to this after handling the
798	 * fault.
799	 */
800	fixup_offset = (long)&ex->fixup - (pc + AARCH64_INSN_SIZE);
801	if (!FIELD_FIT(BPF_FIXUP_OFFSET_MASK, fixup_offset))
802		return -ERANGE;
803
804	/*
805	 * The offsets above have been calculated using the RO buffer, but
806	 * writes must go through the RW buffer, so switch ex to the RW
807	 * buffer before writing.
808	 */
809	ex = (void *)ctx->image + ((void *)ex - (void *)ctx->ro_image);
810
811	ex->insn = ins_offset;
812
813	ex->fixup = FIELD_PREP(BPF_FIXUP_OFFSET_MASK, fixup_offset) |
814		    FIELD_PREP(BPF_FIXUP_REG_MASK, dst_reg);
815
816	ex->type = EX_TYPE_BPF;
817
818	ctx->exentry_idx++;
819	return 0;
820}
821
822/* JITs an eBPF instruction.
823 * Returns:
824 * 0  - successfully JITed an 8-byte eBPF instruction.
825 * >0 - successfully JITed a 16-byte eBPF instruction.
826 * <0 - failed to JIT.
827 */
828static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx,
829		      bool extra_pass)
830{
831	const u8 code = insn->code;
832	const u8 dst = bpf2a64[insn->dst_reg];
833	const u8 src = bpf2a64[insn->src_reg];
834	const u8 tmp = bpf2a64[TMP_REG_1];
835	const u8 tmp2 = bpf2a64[TMP_REG_2];
836	const u8 fp = bpf2a64[BPF_REG_FP];
837	const u8 fpb = bpf2a64[FP_BOTTOM];
838	const s16 off = insn->off;
839	const s32 imm = insn->imm;
840	const int i = insn - ctx->prog->insnsi;
841	const bool is64 = BPF_CLASS(code) == BPF_ALU64 ||
842			  BPF_CLASS(code) == BPF_JMP;
843	u8 jmp_cond;
844	s32 jmp_offset;
845	u32 a64_insn;
846	u8 src_adj;
847	u8 dst_adj;
848	int off_adj;
849	int ret;
850	bool sign_extend;
851
852	switch (code) {
853	/* dst = src */
854	case BPF_ALU | BPF_MOV | BPF_X:
855	case BPF_ALU64 | BPF_MOV | BPF_X:
856		switch (insn->off) {
857		case 0:
858			emit(A64_MOV(is64, dst, src), ctx);
859			break;
860		case 8:
861			emit(A64_SXTB(is64, dst, src), ctx);
862			break;
863		case 16:
864			emit(A64_SXTH(is64, dst, src), ctx);
865			break;
866		case 32:
867			emit(A64_SXTW(is64, dst, src), ctx);
868			break;
869		}
870		break;
871	/* dst = dst OP src */
872	case BPF_ALU | BPF_ADD | BPF_X:
873	case BPF_ALU64 | BPF_ADD | BPF_X:
874		emit(A64_ADD(is64, dst, dst, src), ctx);
875		break;
876	case BPF_ALU | BPF_SUB | BPF_X:
877	case BPF_ALU64 | BPF_SUB | BPF_X:
878		emit(A64_SUB(is64, dst, dst, src), ctx);
879		break;
880	case BPF_ALU | BPF_AND | BPF_X:
881	case BPF_ALU64 | BPF_AND | BPF_X:
882		emit(A64_AND(is64, dst, dst, src), ctx);
883		break;
884	case BPF_ALU | BPF_OR | BPF_X:
885	case BPF_ALU64 | BPF_OR | BPF_X:
886		emit(A64_ORR(is64, dst, dst, src), ctx);
887		break;
888	case BPF_ALU | BPF_XOR | BPF_X:
889	case BPF_ALU64 | BPF_XOR | BPF_X:
890		emit(A64_EOR(is64, dst, dst, src), ctx);
891		break;
892	case BPF_ALU | BPF_MUL | BPF_X:
893	case BPF_ALU64 | BPF_MUL | BPF_X:
894		emit(A64_MUL(is64, dst, dst, src), ctx);
895		break;
896	case BPF_ALU | BPF_DIV | BPF_X:
897	case BPF_ALU64 | BPF_DIV | BPF_X:
898		if (!off)
899			emit(A64_UDIV(is64, dst, dst, src), ctx);
900		else
901			emit(A64_SDIV(is64, dst, dst, src), ctx);
902		break;
903	case BPF_ALU | BPF_MOD | BPF_X:
904	case BPF_ALU64 | BPF_MOD | BPF_X:
905		if (!off)
906			emit(A64_UDIV(is64, tmp, dst, src), ctx);
907		else
908			emit(A64_SDIV(is64, tmp, dst, src), ctx);
909		emit(A64_MSUB(is64, dst, dst, tmp, src), ctx);
910		break;
911	case BPF_ALU | BPF_LSH | BPF_X:
912	case BPF_ALU64 | BPF_LSH | BPF_X:
913		emit(A64_LSLV(is64, dst, dst, src), ctx);
914		break;
915	case BPF_ALU | BPF_RSH | BPF_X:
916	case BPF_ALU64 | BPF_RSH | BPF_X:
917		emit(A64_LSRV(is64, dst, dst, src), ctx);
918		break;
919	case BPF_ALU | BPF_ARSH | BPF_X:
920	case BPF_ALU64 | BPF_ARSH | BPF_X:
921		emit(A64_ASRV(is64, dst, dst, src), ctx);
922		break;
923	/* dst = -dst */
924	case BPF_ALU | BPF_NEG:
925	case BPF_ALU64 | BPF_NEG:
926		emit(A64_NEG(is64, dst, dst), ctx);
927		break;
928	/* dst = BSWAP##imm(dst) */
929	case BPF_ALU | BPF_END | BPF_FROM_LE:
930	case BPF_ALU | BPF_END | BPF_FROM_BE:
931	case BPF_ALU64 | BPF_END | BPF_FROM_LE:
932#ifdef CONFIG_CPU_BIG_ENDIAN
933		if (BPF_CLASS(code) == BPF_ALU && BPF_SRC(code) == BPF_FROM_BE)
934			goto emit_bswap_uxt;
935#else /* !CONFIG_CPU_BIG_ENDIAN */
936		if (BPF_CLASS(code) == BPF_ALU && BPF_SRC(code) == BPF_FROM_LE)
937			goto emit_bswap_uxt;
938#endif
939		switch (imm) {
940		case 16:
941			emit(A64_REV16(is64, dst, dst), ctx);
942			/* zero-extend 16 bits into 64 bits */
943			emit(A64_UXTH(is64, dst, dst), ctx);
944			break;
945		case 32:
946			emit(A64_REV32(0, dst, dst), ctx);
947			/* upper 32 bits already cleared */
948			break;
949		case 64:
950			emit(A64_REV64(dst, dst), ctx);
951			break;
952		}
953		break;
954emit_bswap_uxt:
955		switch (imm) {
956		case 16:
957			/* zero-extend 16 bits into 64 bits */
958			emit(A64_UXTH(is64, dst, dst), ctx);
959			break;
960		case 32:
961			/* zero-extend 32 bits into 64 bits */
962			emit(A64_UXTW(is64, dst, dst), ctx);
963			break;
964		case 64:
965			/* nop */
966			break;
967		}
968		break;
969	/* dst = imm */
970	case BPF_ALU | BPF_MOV | BPF_K:
971	case BPF_ALU64 | BPF_MOV | BPF_K:
972		emit_a64_mov_i(is64, dst, imm, ctx);
973		break;
974	/* dst = dst OP imm */
975	case BPF_ALU | BPF_ADD | BPF_K:
976	case BPF_ALU64 | BPF_ADD | BPF_K:
977		if (is_addsub_imm(imm)) {
978			emit(A64_ADD_I(is64, dst, dst, imm), ctx);
979		} else if (is_addsub_imm(-imm)) {
980			emit(A64_SUB_I(is64, dst, dst, -imm), ctx);
981		} else {
982			emit_a64_mov_i(is64, tmp, imm, ctx);
983			emit(A64_ADD(is64, dst, dst, tmp), ctx);
984		}
985		break;
986	case BPF_ALU | BPF_SUB | BPF_K:
987	case BPF_ALU64 | BPF_SUB | BPF_K:
988		if (is_addsub_imm(imm)) {
989			emit(A64_SUB_I(is64, dst, dst, imm), ctx);
990		} else if (is_addsub_imm(-imm)) {
991			emit(A64_ADD_I(is64, dst, dst, -imm), ctx);
992		} else {
993			emit_a64_mov_i(is64, tmp, imm, ctx);
994			emit(A64_SUB(is64, dst, dst, tmp), ctx);
995		}
996		break;
997	case BPF_ALU | BPF_AND | BPF_K:
998	case BPF_ALU64 | BPF_AND | BPF_K:
999		a64_insn = A64_AND_I(is64, dst, dst, imm);
1000		if (a64_insn != AARCH64_BREAK_FAULT) {
1001			emit(a64_insn, ctx);
1002		} else {
1003			emit_a64_mov_i(is64, tmp, imm, ctx);
1004			emit(A64_AND(is64, dst, dst, tmp), ctx);
1005		}
1006		break;
1007	case BPF_ALU | BPF_OR | BPF_K:
1008	case BPF_ALU64 | BPF_OR | BPF_K:
1009		a64_insn = A64_ORR_I(is64, dst, dst, imm);
1010		if (a64_insn != AARCH64_BREAK_FAULT) {
1011			emit(a64_insn, ctx);
1012		} else {
1013			emit_a64_mov_i(is64, tmp, imm, ctx);
1014			emit(A64_ORR(is64, dst, dst, tmp), ctx);
1015		}
1016		break;
1017	case BPF_ALU | BPF_XOR | BPF_K:
1018	case BPF_ALU64 | BPF_XOR | BPF_K:
1019		a64_insn = A64_EOR_I(is64, dst, dst, imm);
1020		if (a64_insn != AARCH64_BREAK_FAULT) {
1021			emit(a64_insn, ctx);
1022		} else {
1023			emit_a64_mov_i(is64, tmp, imm, ctx);
1024			emit(A64_EOR(is64, dst, dst, tmp), ctx);
1025		}
1026		break;
1027	case BPF_ALU | BPF_MUL | BPF_K:
1028	case BPF_ALU64 | BPF_MUL | BPF_K:
1029		emit_a64_mov_i(is64, tmp, imm, ctx);
1030		emit(A64_MUL(is64, dst, dst, tmp), ctx);
1031		break;
1032	case BPF_ALU | BPF_DIV | BPF_K:
1033	case BPF_ALU64 | BPF_DIV | BPF_K:
1034		emit_a64_mov_i(is64, tmp, imm, ctx);
1035		if (!off)
1036			emit(A64_UDIV(is64, dst, dst, tmp), ctx);
1037		else
1038			emit(A64_SDIV(is64, dst, dst, tmp), ctx);
1039		break;
1040	case BPF_ALU | BPF_MOD | BPF_K:
1041	case BPF_ALU64 | BPF_MOD | BPF_K:
1042		emit_a64_mov_i(is64, tmp2, imm, ctx);
1043		if (!off)
1044			emit(A64_UDIV(is64, tmp, dst, tmp2), ctx);
1045		else
1046			emit(A64_SDIV(is64, tmp, dst, tmp2), ctx);
1047		emit(A64_MSUB(is64, dst, dst, tmp, tmp2), ctx);
1048		break;
1049	case BPF_ALU | BPF_LSH | BPF_K:
1050	case BPF_ALU64 | BPF_LSH | BPF_K:
1051		emit(A64_LSL(is64, dst, dst, imm), ctx);
1052		break;
1053	case BPF_ALU | BPF_RSH | BPF_K:
1054	case BPF_ALU64 | BPF_RSH | BPF_K:
1055		emit(A64_LSR(is64, dst, dst, imm), ctx);
1056		break;
1057	case BPF_ALU | BPF_ARSH | BPF_K:
1058	case BPF_ALU64 | BPF_ARSH | BPF_K:
1059		emit(A64_ASR(is64, dst, dst, imm), ctx);
1060		break;
1061
1062	/* JUMP off */
1063	case BPF_JMP | BPF_JA:
1064	case BPF_JMP32 | BPF_JA:
1065		if (BPF_CLASS(code) == BPF_JMP)
1066			jmp_offset = bpf2a64_offset(i, off, ctx);
1067		else
1068			jmp_offset = bpf2a64_offset(i, imm, ctx);
1069		check_imm26(jmp_offset);
1070		emit(A64_B(jmp_offset), ctx);
1071		break;
1072	/* IF (dst COND src) JUMP off */
1073	case BPF_JMP | BPF_JEQ | BPF_X:
1074	case BPF_JMP | BPF_JGT | BPF_X:
1075	case BPF_JMP | BPF_JLT | BPF_X:
1076	case BPF_JMP | BPF_JGE | BPF_X:
1077	case BPF_JMP | BPF_JLE | BPF_X:
1078	case BPF_JMP | BPF_JNE | BPF_X:
1079	case BPF_JMP | BPF_JSGT | BPF_X:
1080	case BPF_JMP | BPF_JSLT | BPF_X:
1081	case BPF_JMP | BPF_JSGE | BPF_X:
1082	case BPF_JMP | BPF_JSLE | BPF_X:
1083	case BPF_JMP32 | BPF_JEQ | BPF_X:
1084	case BPF_JMP32 | BPF_JGT | BPF_X:
1085	case BPF_JMP32 | BPF_JLT | BPF_X:
1086	case BPF_JMP32 | BPF_JGE | BPF_X:
1087	case BPF_JMP32 | BPF_JLE | BPF_X:
1088	case BPF_JMP32 | BPF_JNE | BPF_X:
1089	case BPF_JMP32 | BPF_JSGT | BPF_X:
1090	case BPF_JMP32 | BPF_JSLT | BPF_X:
1091	case BPF_JMP32 | BPF_JSGE | BPF_X:
1092	case BPF_JMP32 | BPF_JSLE | BPF_X:
1093		emit(A64_CMP(is64, dst, src), ctx);
1094emit_cond_jmp:
1095		jmp_offset = bpf2a64_offset(i, off, ctx);
1096		check_imm19(jmp_offset);
1097		switch (BPF_OP(code)) {
1098		case BPF_JEQ:
1099			jmp_cond = A64_COND_EQ;
1100			break;
1101		case BPF_JGT:
1102			jmp_cond = A64_COND_HI;
1103			break;
1104		case BPF_JLT:
1105			jmp_cond = A64_COND_CC;
1106			break;
1107		case BPF_JGE:
1108			jmp_cond = A64_COND_CS;
1109			break;
1110		case BPF_JLE:
1111			jmp_cond = A64_COND_LS;
1112			break;
1113		case BPF_JSET:
1114		case BPF_JNE:
1115			jmp_cond = A64_COND_NE;
1116			break;
1117		case BPF_JSGT:
1118			jmp_cond = A64_COND_GT;
1119			break;
1120		case BPF_JSLT:
1121			jmp_cond = A64_COND_LT;
1122			break;
1123		case BPF_JSGE:
1124			jmp_cond = A64_COND_GE;
1125			break;
1126		case BPF_JSLE:
1127			jmp_cond = A64_COND_LE;
1128			break;
1129		default:
1130			return -EFAULT;
1131		}
1132		emit(A64_B_(jmp_cond, jmp_offset), ctx);
1133		break;
1134	case BPF_JMP | BPF_JSET | BPF_X:
1135	case BPF_JMP32 | BPF_JSET | BPF_X:
1136		emit(A64_TST(is64, dst, src), ctx);
1137		goto emit_cond_jmp;
1138	/* IF (dst COND imm) JUMP off */
1139	case BPF_JMP | BPF_JEQ | BPF_K:
1140	case BPF_JMP | BPF_JGT | BPF_K:
1141	case BPF_JMP | BPF_JLT | BPF_K:
1142	case BPF_JMP | BPF_JGE | BPF_K:
1143	case BPF_JMP | BPF_JLE | BPF_K:
1144	case BPF_JMP | BPF_JNE | BPF_K:
1145	case BPF_JMP | BPF_JSGT | BPF_K:
1146	case BPF_JMP | BPF_JSLT | BPF_K:
1147	case BPF_JMP | BPF_JSGE | BPF_K:
1148	case BPF_JMP | BPF_JSLE | BPF_K:
1149	case BPF_JMP32 | BPF_JEQ | BPF_K:
1150	case BPF_JMP32 | BPF_JGT | BPF_K:
1151	case BPF_JMP32 | BPF_JLT | BPF_K:
1152	case BPF_JMP32 | BPF_JGE | BPF_K:
1153	case BPF_JMP32 | BPF_JLE | BPF_K:
1154	case BPF_JMP32 | BPF_JNE | BPF_K:
1155	case BPF_JMP32 | BPF_JSGT | BPF_K:
1156	case BPF_JMP32 | BPF_JSLT | BPF_K:
1157	case BPF_JMP32 | BPF_JSGE | BPF_K:
1158	case BPF_JMP32 | BPF_JSLE | BPF_K:
1159		if (is_addsub_imm(imm)) {
1160			emit(A64_CMP_I(is64, dst, imm), ctx);
1161		} else if (is_addsub_imm(-imm)) {
1162			emit(A64_CMN_I(is64, dst, -imm), ctx);
1163		} else {
1164			emit_a64_mov_i(is64, tmp, imm, ctx);
1165			emit(A64_CMP(is64, dst, tmp), ctx);
1166		}
1167		goto emit_cond_jmp;
1168	case BPF_JMP | BPF_JSET | BPF_K:
1169	case BPF_JMP32 | BPF_JSET | BPF_K:
1170		a64_insn = A64_TST_I(is64, dst, imm);
1171		if (a64_insn != AARCH64_BREAK_FAULT) {
1172			emit(a64_insn, ctx);
1173		} else {
1174			emit_a64_mov_i(is64, tmp, imm, ctx);
1175			emit(A64_TST(is64, dst, tmp), ctx);
1176		}
1177		goto emit_cond_jmp;
1178	/* function call */
1179	case BPF_JMP | BPF_CALL:
1180	{
1181		const u8 r0 = bpf2a64[BPF_REG_0];
1182		bool func_addr_fixed;
1183		u64 func_addr;
1184
1185		ret = bpf_jit_get_func_addr(ctx->prog, insn, extra_pass,
1186					    &func_addr, &func_addr_fixed);
1187		if (ret < 0)
1188			return ret;
1189		emit_call(func_addr, ctx);
1190		emit(A64_MOV(1, r0, A64_R(0)), ctx);
1191		break;
1192	}
1193	/* tail call */
1194	case BPF_JMP | BPF_TAIL_CALL:
1195		if (emit_bpf_tail_call(ctx))
1196			return -EFAULT;
1197		break;
1198	/* function return */
1199	case BPF_JMP | BPF_EXIT:
1200		/* Optimization: when the last instruction is EXIT,
1201		   simply fall through to the epilogue. */
1202		if (i == ctx->prog->len - 1)
1203			break;
1204		jmp_offset = epilogue_offset(ctx);
1205		check_imm26(jmp_offset);
1206		emit(A64_B(jmp_offset), ctx);
1207		break;
1208
1209	/* dst = imm64 */
1210	case BPF_LD | BPF_IMM | BPF_DW:
1211	{
1212		const struct bpf_insn insn1 = insn[1];
1213		u64 imm64;
1214
1215		imm64 = (u64)insn1.imm << 32 | (u32)imm;
1216		if (bpf_pseudo_func(insn))
1217			emit_addr_mov_i64(dst, imm64, ctx);
1218		else
1219			emit_a64_mov_i64(dst, imm64, ctx);
1220
1221		return 1;
1222	}
1223
1224	/* LDX: dst = (u64)*(unsigned size *)(src + off) */
1225	case BPF_LDX | BPF_MEM | BPF_W:
1226	case BPF_LDX | BPF_MEM | BPF_H:
1227	case BPF_LDX | BPF_MEM | BPF_B:
1228	case BPF_LDX | BPF_MEM | BPF_DW:
1229	case BPF_LDX | BPF_PROBE_MEM | BPF_DW:
1230	case BPF_LDX | BPF_PROBE_MEM | BPF_W:
1231	case BPF_LDX | BPF_PROBE_MEM | BPF_H:
1232	case BPF_LDX | BPF_PROBE_MEM | BPF_B:
1233	/* LDXS: dst_reg = (s64)*(signed size *)(src_reg + off) */
1234	case BPF_LDX | BPF_MEMSX | BPF_B:
1235	case BPF_LDX | BPF_MEMSX | BPF_H:
1236	case BPF_LDX | BPF_MEMSX | BPF_W:
1237	case BPF_LDX | BPF_PROBE_MEMSX | BPF_B:
1238	case BPF_LDX | BPF_PROBE_MEMSX | BPF_H:
1239	case BPF_LDX | BPF_PROBE_MEMSX | BPF_W:
1240		if (ctx->fpb_offset > 0 && src == fp) {
1241			src_adj = fpb;
1242			off_adj = off + ctx->fpb_offset;
1243		} else {
1244			src_adj = src;
1245			off_adj = off;
1246		}
1247		sign_extend = (BPF_MODE(insn->code) == BPF_MEMSX ||
1248				BPF_MODE(insn->code) == BPF_PROBE_MEMSX);
1249		switch (BPF_SIZE(code)) {
1250		case BPF_W:
1251			if (is_lsi_offset(off_adj, 2)) {
1252				if (sign_extend)
1253					emit(A64_LDRSWI(dst, src_adj, off_adj), ctx);
1254				else
1255					emit(A64_LDR32I(dst, src_adj, off_adj), ctx);
1256			} else {
1257				emit_a64_mov_i(1, tmp, off, ctx);
1258				if (sign_extend)
1259					emit(A64_LDRSW(dst, src, tmp), ctx);
1260				else
1261					emit(A64_LDR32(dst, src, tmp), ctx);
1262			}
1263			break;
1264		case BPF_H:
1265			if (is_lsi_offset(off_adj, 1)) {
1266				if (sign_extend)
1267					emit(A64_LDRSHI(dst, src_adj, off_adj), ctx);
1268				else
1269					emit(A64_LDRHI(dst, src_adj, off_adj), ctx);
1270			} else {
1271				emit_a64_mov_i(1, tmp, off, ctx);
1272				if (sign_extend)
1273					emit(A64_LDRSH(dst, src, tmp), ctx);
1274				else
1275					emit(A64_LDRH(dst, src, tmp), ctx);
1276			}
1277			break;
1278		case BPF_B:
1279			if (is_lsi_offset(off_adj, 0)) {
1280				if (sign_extend)
1281					emit(A64_LDRSBI(dst, src_adj, off_adj), ctx);
1282				else
1283					emit(A64_LDRBI(dst, src_adj, off_adj), ctx);
1284			} else {
1285				emit_a64_mov_i(1, tmp, off, ctx);
1286				if (sign_extend)
1287					emit(A64_LDRSB(dst, src, tmp), ctx);
1288				else
1289					emit(A64_LDRB(dst, src, tmp), ctx);
1290			}
1291			break;
1292		case BPF_DW:
1293			if (is_lsi_offset(off_adj, 3)) {
1294				emit(A64_LDR64I(dst, src_adj, off_adj), ctx);
1295			} else {
1296				emit_a64_mov_i(1, tmp, off, ctx);
1297				emit(A64_LDR64(dst, src, tmp), ctx);
1298			}
1299			break;
1300		}
1301
1302		ret = add_exception_handler(insn, ctx, dst);
1303		if (ret)
1304			return ret;
1305		break;
1306
1307	/* speculation barrier */
1308	case BPF_ST | BPF_NOSPEC:
1309		/*
1310		 * Nothing required here.
1311		 *
1312		 * In case of arm64, we rely on the firmware mitigation of
1313		 * Speculative Store Bypass as controlled via the ssbd kernel
1314		 * parameter. Whenever the mitigation is enabled, it works
1315		 * for all of the kernel code with no need to provide any
1316		 * additional instructions.
1317		 */
1318		break;
1319
1320	/* ST: *(size *)(dst + off) = imm */
1321	case BPF_ST | BPF_MEM | BPF_W:
1322	case BPF_ST | BPF_MEM | BPF_H:
1323	case BPF_ST | BPF_MEM | BPF_B:
1324	case BPF_ST | BPF_MEM | BPF_DW:
1325		if (ctx->fpb_offset > 0 && dst == fp) {
1326			dst_adj = fpb;
1327			off_adj = off + ctx->fpb_offset;
1328		} else {
1329			dst_adj = dst;
1330			off_adj = off;
1331		}
1332		/* Load imm to a register then store it */
1333		emit_a64_mov_i(1, tmp, imm, ctx);
1334		switch (BPF_SIZE(code)) {
1335		case BPF_W:
1336			if (is_lsi_offset(off_adj, 2)) {
1337				emit(A64_STR32I(tmp, dst_adj, off_adj), ctx);
1338			} else {
1339				emit_a64_mov_i(1, tmp2, off, ctx);
1340				emit(A64_STR32(tmp, dst, tmp2), ctx);
1341			}
1342			break;
1343		case BPF_H:
1344			if (is_lsi_offset(off_adj, 1)) {
1345				emit(A64_STRHI(tmp, dst_adj, off_adj), ctx);
1346			} else {
1347				emit_a64_mov_i(1, tmp2, off, ctx);
1348				emit(A64_STRH(tmp, dst, tmp2), ctx);
1349			}
1350			break;
1351		case BPF_B:
1352			if (is_lsi_offset(off_adj, 0)) {
1353				emit(A64_STRBI(tmp, dst_adj, off_adj), ctx);
1354			} else {
1355				emit_a64_mov_i(1, tmp2, off, ctx);
1356				emit(A64_STRB(tmp, dst, tmp2), ctx);
1357			}
1358			break;
1359		case BPF_DW:
1360			if (is_lsi_offset(off_adj, 3)) {
1361				emit(A64_STR64I(tmp, dst_adj, off_adj), ctx);
1362			} else {
1363				emit_a64_mov_i(1, tmp2, off, ctx);
1364				emit(A64_STR64(tmp, dst, tmp2), ctx);
1365			}
1366			break;
1367		}
1368		break;
1369
1370	/* STX: *(size *)(dst + off) = src */
1371	case BPF_STX | BPF_MEM | BPF_W:
1372	case BPF_STX | BPF_MEM | BPF_H:
1373	case BPF_STX | BPF_MEM | BPF_B:
1374	case BPF_STX | BPF_MEM | BPF_DW:
1375		if (ctx->fpb_offset > 0 && dst == fp) {
1376			dst_adj = fpb;
1377			off_adj = off + ctx->fpb_offset;
1378		} else {
1379			dst_adj = dst;
1380			off_adj = off;
1381		}
1382		switch (BPF_SIZE(code)) {
1383		case BPF_W:
1384			if (is_lsi_offset(off_adj, 2)) {
1385				emit(A64_STR32I(src, dst_adj, off_adj), ctx);
1386			} else {
1387				emit_a64_mov_i(1, tmp, off, ctx);
1388				emit(A64_STR32(src, dst, tmp), ctx);
1389			}
1390			break;
1391		case BPF_H:
1392			if (is_lsi_offset(off_adj, 1)) {
1393				emit(A64_STRHI(src, dst_adj, off_adj), ctx);
1394			} else {
1395				emit_a64_mov_i(1, tmp, off, ctx);
1396				emit(A64_STRH(src, dst, tmp), ctx);
1397			}
1398			break;
1399		case BPF_B:
1400			if (is_lsi_offset(off_adj, 0)) {
1401				emit(A64_STRBI(src, dst_adj, off_adj), ctx);
1402			} else {
1403				emit_a64_mov_i(1, tmp, off, ctx);
1404				emit(A64_STRB(src, dst, tmp), ctx);
1405			}
1406			break;
1407		case BPF_DW:
1408			if (is_lsi_offset(off_adj, 3)) {
1409				emit(A64_STR64I(src, dst_adj, off_adj), ctx);
1410			} else {
1411				emit_a64_mov_i(1, tmp, off, ctx);
1412				emit(A64_STR64(src, dst, tmp), ctx);
1413			}
1414			break;
1415		}
1416		break;
1417
1418	case BPF_STX | BPF_ATOMIC | BPF_W:
1419	case BPF_STX | BPF_ATOMIC | BPF_DW:
1420		if (cpus_have_cap(ARM64_HAS_LSE_ATOMICS))
1421			ret = emit_lse_atomic(insn, ctx);
1422		else
1423			ret = emit_ll_sc_atomic(insn, ctx);
1424		if (ret)
1425			return ret;
1426		break;
1427
1428	default:
1429		pr_err_once("unknown opcode %02x\n", code);
1430		return -EINVAL;
1431	}
1432
1433	return 0;
1434}
1435
1436/*
1437 * Return 0 if FP may change at runtime. Otherwise find the minimum negative
1438 * offset from FP, convert it to a positive number, and align it down to 8 bytes.
1439 */
1440static int find_fpb_offset(struct bpf_prog *prog)
1441{
1442	int i;
1443	int offset = 0;
1444
1445	for (i = 0; i < prog->len; i++) {
1446		const struct bpf_insn *insn = &prog->insnsi[i];
1447		const u8 class = BPF_CLASS(insn->code);
1448		const u8 mode = BPF_MODE(insn->code);
1449		const u8 src = insn->src_reg;
1450		const u8 dst = insn->dst_reg;
1451		const s32 imm = insn->imm;
1452		const s16 off = insn->off;
1453
1454		switch (class) {
1455		case BPF_STX:
1456		case BPF_ST:
1457			/* fp holds atomic operation result */
1458			if (class == BPF_STX && mode == BPF_ATOMIC &&
1459			    ((imm == BPF_XCHG ||
1460			      imm == (BPF_FETCH | BPF_ADD) ||
1461			      imm == (BPF_FETCH | BPF_AND) ||
1462			      imm == (BPF_FETCH | BPF_XOR) ||
1463			      imm == (BPF_FETCH | BPF_OR)) &&
1464			     src == BPF_REG_FP))
1465				return 0;
1466
1467			if (mode == BPF_MEM && dst == BPF_REG_FP &&
1468			    off < offset)
1469				offset = insn->off;
1470			break;
1471
1472		case BPF_JMP32:
1473		case BPF_JMP:
1474			break;
1475
1476		case BPF_LDX:
1477		case BPF_LD:
1478			/* fp holds load result */
1479			if (dst == BPF_REG_FP)
1480				return 0;
1481
1482			if (class == BPF_LDX && mode == BPF_MEM &&
1483			    src == BPF_REG_FP && off < offset)
1484				offset = off;
1485			break;
1486
1487		case BPF_ALU:
1488		case BPF_ALU64:
1489		default:
1490			/* fp holds ALU result */
1491			if (dst == BPF_REG_FP)
1492				return 0;
1493		}
1494	}
1495
1496	if (offset < 0) {
1497		/*
1498		 * 'offset' can safely be converted to a positive 'int', since
1499		 * insn->off is an 's16'.
1500		 */
1501		offset = -offset;
1502		/* align down to 8 bytes */
1503		offset = ALIGN_DOWN(offset, 8);
1504	}
1505
1506	return offset;
1507}
1508
1509static int build_body(struct jit_ctx *ctx, bool extra_pass)
1510{
1511	const struct bpf_prog *prog = ctx->prog;
1512	int i;
1513
1514	/*
1515	 * - offset[0] - offset of the end of the prologue,
1516	 *   start of the 1st instruction.
1517	 * - offset[1] - offset of the end of the 1st instruction,
1518	 *   start of the 2nd instruction.
1519	 * [....]
1520	 * - offset[3] - offset of the end of the 3rd instruction,
1521	 *   start of the 4th instruction.
1522	 */
1523	for (i = 0; i < prog->len; i++) {
1524		const struct bpf_insn *insn = &prog->insnsi[i];
1525		int ret;
1526
1527		if (ctx->image == NULL)
1528			ctx->offset[i] = ctx->idx;
1529		ret = build_insn(insn, ctx, extra_pass);
1530		if (ret > 0) {
1531			i++;
1532			if (ctx->image == NULL)
1533				ctx->offset[i] = ctx->idx;
1534			continue;
1535		}
1536		if (ret)
1537			return ret;
1538	}
1539	/*
1540	 * offset is allocated with prog->len + 1 so fill in
1541	 * the last element with the offset after the last
1542	 * instruction (end of program)
1543	 */
1544	if (ctx->image == NULL)
1545		ctx->offset[i] = ctx->idx;
1546
1547	return 0;
1548}
1549
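/*
 * A break instruction left in the image means an instruction encoder
 * failed (the A64 helpers return AARCH64_BREAK_FAULT on error), so
 * reject such an image.
 */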
1550static int validate_code(struct jit_ctx *ctx)
1551{
1552	int i;
1553
1554	for (i = 0; i < ctx->idx; i++) {
1555		u32 a64_insn = le32_to_cpu(ctx->image[i]);
1556
1557		if (a64_insn == AARCH64_BREAK_FAULT)
1558			return -1;
1559	}
1560	return 0;
1561}
1562
1563static int validate_ctx(struct jit_ctx *ctx)
1564{
1565	if (validate_code(ctx))
1566		return -1;
1567
1568	if (WARN_ON_ONCE(ctx->exentry_idx != ctx->prog->aux->num_exentries))
1569		return -1;
1570
1571	return 0;
1572}
1573
1574static inline void bpf_flush_icache(void *start, void *end)
1575{
1576	flush_icache_range((unsigned long)start, (unsigned long)end);
1577}
1578
1579struct arm64_jit_data {
1580	struct bpf_binary_header *header;
1581	u8 *ro_image;
1582	struct bpf_binary_header *ro_header;
1583	struct jit_ctx ctx;
1584};
1585
1586struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
1587{
1588	int image_size, prog_size, extable_size, extable_align, extable_offset;
1589	struct bpf_prog *tmp, *orig_prog = prog;
1590	struct bpf_binary_header *header;
1591	struct bpf_binary_header *ro_header;
1592	struct arm64_jit_data *jit_data;
1593	bool was_classic = bpf_prog_was_classic(prog);
1594	bool tmp_blinded = false;
1595	bool extra_pass = false;
1596	struct jit_ctx ctx;
1597	u8 *image_ptr;
1598	u8 *ro_image_ptr;
1599
1600	if (!prog->jit_requested)
1601		return orig_prog;
1602
1603	tmp = bpf_jit_blind_constants(prog);
1604	/* If blinding was requested and we failed during blinding,
1605	 * we must fall back to the interpreter.
1606	 */
1607	if (IS_ERR(tmp))
1608		return orig_prog;
1609	if (tmp != prog) {
1610		tmp_blinded = true;
1611		prog = tmp;
1612	}
1613
1614	jit_data = prog->aux->jit_data;
1615	if (!jit_data) {
1616		jit_data = kzalloc(sizeof(*jit_data), GFP_KERNEL);
1617		if (!jit_data) {
1618			prog = orig_prog;
1619			goto out;
1620		}
1621		prog->aux->jit_data = jit_data;
1622	}
1623	if (jit_data->ctx.offset) {
1624		ctx = jit_data->ctx;
1625		ro_image_ptr = jit_data->ro_image;
1626		ro_header = jit_data->ro_header;
1627		header = jit_data->header;
1628		image_ptr = (void *)header + ((void *)ro_image_ptr
1629						 - (void *)ro_header);
1630		extra_pass = true;
1631		prog_size = sizeof(u32) * ctx.idx;
1632		goto skip_init_ctx;
1633	}
1634	memset(&ctx, 0, sizeof(ctx));
1635	ctx.prog = prog;
1636
1637	ctx.offset = kvcalloc(prog->len + 1, sizeof(int), GFP_KERNEL);
1638	if (ctx.offset == NULL) {
1639		prog = orig_prog;
1640		goto out_off;
1641	}
1642
1643	ctx.fpb_offset = find_fpb_offset(prog);
1644
1645	/*
1646	 * 1. Initial fake pass to compute ctx->idx and ctx->offset.
1647	 *
1648	 * BPF line info needs ctx->offset[i] to be the offset of
1649	 * instruction[i] in jited image, so build prologue first.
1650	 */
1651	if (build_prologue(&ctx, was_classic, prog->aux->exception_cb)) {
1652		prog = orig_prog;
1653		goto out_off;
1654	}
1655
1656	if (build_body(&ctx, extra_pass)) {
1657		prog = orig_prog;
1658		goto out_off;
1659	}
1660
1661	ctx.epilogue_offset = ctx.idx;
1662	build_epilogue(&ctx, prog->aux->exception_cb);
1663	build_plt(&ctx);
1664
1665	extable_align = __alignof__(struct exception_table_entry);
1666	extable_size = prog->aux->num_exentries *
1667		sizeof(struct exception_table_entry);
1668
1669	/* Now we know the actual image size. */
1670	prog_size = sizeof(u32) * ctx.idx;
1671	/* also allocate space for plt target */
1672	extable_offset = round_up(prog_size + PLT_TARGET_SIZE, extable_align);
1673	image_size = extable_offset + extable_size;
1674	ro_header = bpf_jit_binary_pack_alloc(image_size, &ro_image_ptr,
1675					      sizeof(u32), &header, &image_ptr,
1676					      jit_fill_hole);
1677	if (!ro_header) {
1678		prog = orig_prog;
1679		goto out_off;
1680	}
1681
1682	/* 2. Now, the actual pass. */
1683
1684	/*
1685	 * Use the image (RW) for writing the JITed instructions, but keep
1686	 * ro_image (RX) for calculating the offsets, since that is where
1687	 * the program will eventually run.
1688	 * bpf_jit_binary_pack_finalize() copies the RW image to the RX
1689	 * image as the final step.
1690	 */
1691	ctx.image = (__le32 *)image_ptr;
1692	ctx.ro_image = (__le32 *)ro_image_ptr;
1693	if (extable_size)
1694		prog->aux->extable = (void *)ro_image_ptr + extable_offset;
1695skip_init_ctx:
1696	ctx.idx = 0;
1697	ctx.exentry_idx = 0;
1698
1699	build_prologue(&ctx, was_classic, prog->aux->exception_cb);
1700
1701	if (build_body(&ctx, extra_pass)) {
1702		prog = orig_prog;
1703		goto out_free_hdr;
1704	}
1705
1706	build_epilogue(&ctx, prog->aux->exception_cb);
1707	build_plt(&ctx);
1708
1709	/* 3. Extra pass to validate JITed code. */
1710	if (validate_ctx(&ctx)) {
1711		prog = orig_prog;
1712		goto out_free_hdr;
1713	}
1714
1715	/* And we're done. */
1716	if (bpf_jit_enable > 1)
1717		bpf_jit_dump(prog->len, prog_size, 2, ctx.image);
1718
1719	if (!prog->is_func || extra_pass) {
1720		if (extra_pass && ctx.idx != jit_data->ctx.idx) {
1721			pr_err_once("multi-func JIT bug %d != %d\n",
1722				    ctx.idx, jit_data->ctx.idx);
1723			prog->bpf_func = NULL;
1724			prog->jited = 0;
1725			prog->jited_len = 0;
1726			goto out_free_hdr;
1727		}
1728		if (WARN_ON(bpf_jit_binary_pack_finalize(prog, ro_header,
1729							 header))) {
1730			/* ro_header has been freed */
1731			ro_header = NULL;
1732			prog = orig_prog;
1733			goto out_off;
1734		}
1735		/*
1736		 * The instructions have now been copied to the ROX region from
1737		 * where they will execute. Now the data cache has to be cleaned to
1738		 * the PoU and the I-cache has to be invalidated for the VAs.
1739		 */
1740		bpf_flush_icache(ro_header, ctx.ro_image + ctx.idx);
1741	} else {
1742		jit_data->ctx = ctx;
1743		jit_data->ro_image = ro_image_ptr;
1744		jit_data->header = header;
1745		jit_data->ro_header = ro_header;
1746	}
1747
1748	prog->bpf_func = (void *)ctx.ro_image;
1749	prog->jited = 1;
1750	prog->jited_len = prog_size;
1751
1752	if (!prog->is_func || extra_pass) {
1753		int i;
1754
1755		/* offset[prog->len] is the size of program */
1756		for (i = 0; i <= prog->len; i++)
1757			ctx.offset[i] *= AARCH64_INSN_SIZE;
1758		bpf_prog_fill_jited_linfo(prog, ctx.offset + 1);
1759out_off:
1760		kvfree(ctx.offset);
1761		kfree(jit_data);
1762		prog->aux->jit_data = NULL;
1763	}
1764out:
1765	if (tmp_blinded)
1766		bpf_jit_prog_release_other(prog, prog == orig_prog ?
1767					   tmp : orig_prog);
1768	return prog;
1769
1770out_free_hdr:
1771	if (header) {
1772		bpf_arch_text_copy(&ro_header->size, &header->size,
1773				   sizeof(header->size));
1774		bpf_jit_binary_pack_free(ro_header, header);
1775	}
1776	goto out_off;
1777}
1778
1779bool bpf_jit_supports_kfunc_call(void)
1780{
1781	return true;
1782}
1783
1784void *bpf_arch_text_copy(void *dst, void *src, size_t len)
1785{
1786	if (!aarch64_insn_copy(dst, src, len))
1787		return ERR_PTR(-EINVAL);
1788	return dst;
1789}
1790
1791u64 bpf_jit_alloc_exec_limit(void)
1792{
1793	return VMALLOC_END - VMALLOC_START;
1794}
1795
1796void *bpf_jit_alloc_exec(unsigned long size)
1797{
1798	/* Memory is intended to be executable, reset the pointer tag. */
1799	return kasan_reset_tag(vmalloc(size));
1800}
1801
1802void bpf_jit_free_exec(void *addr)
1803{
1804	return vfree(addr);
1805}
1806
1807/* Indicate the JIT backend supports mixing bpf2bpf and tailcalls. */
1808bool bpf_jit_supports_subprog_tailcalls(void)
1809{
1810	return true;
1811}
1812
1813static void invoke_bpf_prog(struct jit_ctx *ctx, struct bpf_tramp_link *l,
1814			    int args_off, int retval_off, int run_ctx_off,
1815			    bool save_ret)
1816{
1817	__le32 *branch;
1818	u64 enter_prog;
1819	u64 exit_prog;
1820	struct bpf_prog *p = l->link.prog;
1821	int cookie_off = offsetof(struct bpf_tramp_run_ctx, bpf_cookie);
1822
1823	enter_prog = (u64)bpf_trampoline_enter(p);
1824	exit_prog = (u64)bpf_trampoline_exit(p);
1825
1826	if (l->cookie == 0) {
1827		/* if cookie is zero, one instruction is enough to store it */
1828		emit(A64_STR64I(A64_ZR, A64_SP, run_ctx_off + cookie_off), ctx);
1829	} else {
1830		emit_a64_mov_i64(A64_R(10), l->cookie, ctx);
1831		emit(A64_STR64I(A64_R(10), A64_SP, run_ctx_off + cookie_off),
1832		     ctx);
1833	}
1834
1835	/* save p to callee saved register x19 to avoid loading p with mov_i64
1836	 * each time.
1837	 */
1838	emit_addr_mov_i64(A64_R(19), (const u64)p, ctx);
1839
1840	/* arg1: prog */
1841	emit(A64_MOV(1, A64_R(0), A64_R(19)), ctx);
1842	/* arg2: &run_ctx */
1843	emit(A64_ADD_I(1, A64_R(1), A64_SP, run_ctx_off), ctx);
1844
1845	emit_call(enter_prog, ctx);
1846
1847	/* if (__bpf_prog_enter(prog) == 0)
1848	 *         goto skip_exec_of_prog;
1849	 */
1850	branch = ctx->image + ctx->idx;
1851	emit(A64_NOP, ctx);
1852
1853	/* save return value to callee saved register x20 */
1854	emit(A64_MOV(1, A64_R(20), A64_R(0)), ctx);
1855
1856	emit(A64_ADD_I(1, A64_R(0), A64_SP, args_off), ctx);
1857	if (!p->jited)
1858		emit_addr_mov_i64(A64_R(1), (const u64)p->insnsi, ctx);
1859
1860	emit_call((const u64)p->bpf_func, ctx);
1861
1862	if (save_ret)
1863		emit(A64_STR64I(A64_R(0), A64_SP, retval_off), ctx);
1864
	if (ctx->image) {
		int offset = &ctx->image[ctx->idx] - branch;
		*branch = cpu_to_le32(A64_CBZ(1, A64_R(0), offset));
	}

	/* arg1: prog */
	emit(A64_MOV(1, A64_R(0), A64_R(19)), ctx);
	/* arg2: start time */
	emit(A64_MOV(1, A64_R(1), A64_R(20)), ctx);
	/* arg3: &run_ctx */
	emit(A64_ADD_I(1, A64_R(2), A64_SP, run_ctx_off), ctx);

	emit_call(exit_prog, ctx);
}

static void invoke_bpf_mod_ret(struct jit_ctx *ctx, struct bpf_tramp_links *tl,
			       int args_off, int retval_off, int run_ctx_off,
			       __le32 **branches)
{
	int i;

	/* The first fmod_ret program will receive a garbage return value.
	 * Set this to 0 to avoid confusing the program.
	 */
	emit(A64_STR64I(A64_ZR, A64_SP, retval_off), ctx);
	for (i = 0; i < tl->nr_links; i++) {
		invoke_bpf_prog(ctx, tl->links[i], args_off, retval_off,
				run_ctx_off, true);
		/* if (*(u64 *)(sp + retval_off) != 0)
		 *	goto do_fexit;
		 */
		emit(A64_LDR64I(A64_R(10), A64_SP, retval_off), ctx);
		/* Save the location of the branch, and generate a nop.
		 * This nop will be replaced with a cbnz later.
		 */
		branches[i] = ctx->image + ctx->idx;
		emit(A64_NOP, ctx);
	}
}

static void save_args(struct jit_ctx *ctx, int args_off, int nregs)
{
	int i;

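	/* x0..x(nregs-1) hold the traced function's arguments on entry, and
	 * register number i encodes xi, so i can be used as the source
	 * register directly
	 */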
	for (i = 0; i < nregs; i++) {
		emit(A64_STR64I(i, A64_SP, args_off), ctx);
		args_off += 8;
	}
}

static void restore_args(struct jit_ctx *ctx, int args_off, int nregs)
{
	int i;

	for (i = 0; i < nregs; i++) {
		emit(A64_LDR64I(i, A64_SP, args_off), ctx);
		args_off += 8;
	}
}

/* Based on the x86's implementation of arch_prepare_bpf_trampoline().
 *
 * bpf prog and function entry before bpf trampoline hooked:
 *   mov x9, lr
 *   nop
 *
 * bpf prog and function entry after bpf trampoline hooked:
 *   mov x9, lr
 *   bl  <bpf_trampoline or plt>
 *
 */
static int prepare_trampoline(struct jit_ctx *ctx, struct bpf_tramp_image *im,
			      struct bpf_tramp_links *tlinks, void *func_addr,
			      int nregs, u32 flags)
{
	int i;
	int stack_size;
	int retaddr_off;
	int regs_off;
	int retval_off;
	int args_off;
	int nregs_off;
	int ip_off;
	int run_ctx_off;
	struct bpf_tramp_links *fentry = &tlinks[BPF_TRAMP_FENTRY];
	struct bpf_tramp_links *fexit = &tlinks[BPF_TRAMP_FEXIT];
	struct bpf_tramp_links *fmod_ret = &tlinks[BPF_TRAMP_MODIFY_RETURN];
	bool save_ret;
	__le32 **branches = NULL;

	/* trampoline stack layout:
	 *                  [ parent ip         ]
	 *                  [ FP                ]
	 * SP + retaddr_off [ self ip           ]
	 *                  [ FP                ]
	 *
	 *                  [ padding           ] align SP to multiples of 16
	 *
	 *                  [ x20               ] callee saved reg x20
	 * SP + regs_off    [ x19               ] callee saved reg x19
	 *
	 * SP + retval_off  [ return value      ] BPF_TRAMP_F_CALL_ORIG or
	 *                                        BPF_TRAMP_F_RET_FENTRY_RET
	 *
	 *                  [ arg reg N         ]
	 *                  [ ...               ]
	 * SP + args_off    [ arg reg 1         ]
	 *
	 * SP + nregs_off   [ arg regs count    ]
	 *
	 * SP + ip_off      [ traced function   ] BPF_TRAMP_F_IP_ARG flag
	 *
	 * SP + run_ctx_off [ bpf_tramp_run_ctx ]
	 */

	stack_size = 0;
	run_ctx_off = stack_size;
	/* room for bpf_tramp_run_ctx */
	stack_size += round_up(sizeof(struct bpf_tramp_run_ctx), 8);

	ip_off = stack_size;
	/* room for IP address argument */
	if (flags & BPF_TRAMP_F_IP_ARG)
		stack_size += 8;

	nregs_off = stack_size;
	/* room for args count */
	stack_size += 8;

	args_off = stack_size;
	/* room for args */
	stack_size += nregs * 8;

	/* room for return value */
	retval_off = stack_size;
	save_ret = flags & (BPF_TRAMP_F_CALL_ORIG | BPF_TRAMP_F_RET_FENTRY_RET);
	if (save_ret)
		stack_size += 8;

	/* room for callee saved registers, currently x19 and x20 are used */
	regs_off = stack_size;
	stack_size += 16;

	/* round up to multiples of 16 to avoid SPAlignmentFault */
	stack_size = round_up(stack_size, 16);

	/* the return address is stored right above the saved FP */
	retaddr_off = stack_size + 8;

	/* bpf trampoline may be invoked by 3 instruction types:
	 * 1. bl, attached to bpf prog or kernel function via short jump
	 * 2. br, attached to bpf prog or kernel function via long jump
	 * 3. blr, working as a function pointer, used by struct_ops.
	 * So BTI_JC should be used here to support both br and blr.
	 */
	emit_bti(A64_BTI_JC, ctx);

	/* frame for parent function */
	emit(A64_PUSH(A64_FP, A64_R(9), A64_SP), ctx);
	emit(A64_MOV(1, A64_FP, A64_SP), ctx);

	/* frame for patched function */
	emit(A64_PUSH(A64_FP, A64_LR, A64_SP), ctx);
	emit(A64_MOV(1, A64_FP, A64_SP), ctx);

	/* allocate stack space */
	emit(A64_SUB_I(1, A64_SP, A64_SP, stack_size), ctx);

	if (flags & BPF_TRAMP_F_IP_ARG) {
		/* save ip address of the traced function */
		emit_addr_mov_i64(A64_R(10), (const u64)func_addr, ctx);
		emit(A64_STR64I(A64_R(10), A64_SP, ip_off), ctx);
	}

	/* save arg regs count */
	emit(A64_MOVZ(1, A64_R(10), nregs, 0), ctx);
	emit(A64_STR64I(A64_R(10), A64_SP, nregs_off), ctx);

	/* save arg regs */
	save_args(ctx, args_off, nregs);

	/* save callee saved registers */
	emit(A64_STR64I(A64_R(19), A64_SP, regs_off), ctx);
	emit(A64_STR64I(A64_R(20), A64_SP, regs_off + 8), ctx);

	if (flags & BPF_TRAMP_F_CALL_ORIG) {
		emit_addr_mov_i64(A64_R(0), (const u64)im, ctx);
		emit_call((const u64)__bpf_tramp_enter, ctx);
	}

	for (i = 0; i < fentry->nr_links; i++)
		invoke_bpf_prog(ctx, fentry->links[i], args_off,
				retval_off, run_ctx_off,
				flags & BPF_TRAMP_F_RET_FENTRY_RET);

	if (fmod_ret->nr_links) {
		branches = kcalloc(fmod_ret->nr_links, sizeof(__le32 *),
				   GFP_KERNEL);
		if (!branches)
			return -ENOMEM;

		invoke_bpf_mod_ret(ctx, fmod_ret, args_off, retval_off,
				   run_ctx_off, branches);
	}

	if (flags & BPF_TRAMP_F_CALL_ORIG) {
		restore_args(ctx, args_off, nregs);
		/* call original func */
		emit(A64_LDR64I(A64_R(10), A64_SP, retaddr_off), ctx);
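		/* x10 now holds the return address into the traced function,
		 * i.e. the instruction right after the patchsite; point LR at
		 * the str below so the traced function comes back into the
		 * trampoline, then "call" it with ret x10
		 */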
		emit(A64_ADR(A64_LR, AARCH64_INSN_SIZE * 2), ctx);
		emit(A64_RET(A64_R(10)), ctx);
		/* store return value */
		emit(A64_STR64I(A64_R(0), A64_SP, retval_off), ctx);
		/* reserve a nop for bpf_tramp_image_put */
		im->ip_after_call = ctx->ro_image + ctx->idx;
		emit(A64_NOP, ctx);
	}

	/* update the branches saved in invoke_bpf_mod_ret with cbnz */
	for (i = 0; i < fmod_ret->nr_links && ctx->image != NULL; i++) {
		int offset = &ctx->image[ctx->idx] - branches[i];
		*branches[i] = cpu_to_le32(A64_CBNZ(1, A64_R(10), offset));
	}

	for (i = 0; i < fexit->nr_links; i++)
		invoke_bpf_prog(ctx, fexit->links[i], args_off, retval_off,
				run_ctx_off, false);

	if (flags & BPF_TRAMP_F_CALL_ORIG) {
		im->ip_epilogue = ctx->ro_image + ctx->idx;
		emit_addr_mov_i64(A64_R(0), (const u64)im, ctx);
		emit_call((const u64)__bpf_tramp_exit, ctx);
	}

	if (flags & BPF_TRAMP_F_RESTORE_REGS)
		restore_args(ctx, args_off, nregs);

	/* restore callee saved registers x19 and x20 */
	emit(A64_LDR64I(A64_R(19), A64_SP, regs_off), ctx);
	emit(A64_LDR64I(A64_R(20), A64_SP, regs_off + 8), ctx);

	if (save_ret)
		emit(A64_LDR64I(A64_R(0), A64_SP, retval_off), ctx);

	/* reset SP */
	emit(A64_MOV(1, A64_SP, A64_FP), ctx);

	/* pop frames */
	emit(A64_POP(A64_FP, A64_LR, A64_SP), ctx);
	emit(A64_POP(A64_FP, A64_R(9), A64_SP), ctx);

	if (flags & BPF_TRAMP_F_SKIP_FRAME) {
		/* skip patched function, return to parent */
		emit(A64_MOV(1, A64_LR, A64_R(9)), ctx);
		emit(A64_RET(A64_R(9)), ctx);
	} else {
		/* return to patched function */
		emit(A64_MOV(1, A64_R(10), A64_LR), ctx);
		emit(A64_MOV(1, A64_LR, A64_R(9)), ctx);
		emit(A64_RET(A64_R(10)), ctx);
	}

	kfree(branches);

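	/* return the number of instructions emitted (or merely counted when
	 * ctx->image is NULL)
	 */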
	return ctx->idx;
}

static int btf_func_model_nregs(const struct btf_func_model *m)
{
	int nregs = m->nr_args;
	int i;

	/* extra registers needed for struct argument */
	for (i = 0; i < MAX_BPF_FUNC_ARGS; i++) {
		/* The arg_size is at most 16 bytes, enforced by the verifier. */
		if (m->arg_flags[i] & BTF_FMODEL_STRUCT_ARG)
			nregs += (m->arg_size[i] + 7) / 8 - 1;
	}

	return nregs;
}

int arch_bpf_trampoline_size(const struct btf_func_model *m, u32 flags,
			     struct bpf_tramp_links *tlinks, void *func_addr)
{
	struct jit_ctx ctx = {
		.image = NULL,
		.idx = 0,
	};
	struct bpf_tramp_image im;
	int nregs, ret;

	nregs = btf_func_model_nregs(m);
	/* the first 8 registers are used for arguments */
	if (nregs > 8)
		return -ENOTSUPP;

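	/* with a NULL image, prepare_trampoline() performs a dry run that
	 * only counts instructions; convert the count to a size in bytes
	 */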
	ret = prepare_trampoline(&ctx, &im, tlinks, func_addr, nregs, flags);
	if (ret < 0)
		return ret;

	return ret * AARCH64_INSN_SIZE;
}

void *arch_alloc_bpf_trampoline(unsigned int size)
{
	return bpf_prog_pack_alloc(size, jit_fill_hole);
}

void arch_free_bpf_trampoline(void *image, unsigned int size)
{
	bpf_prog_pack_free(image, size);
}

void arch_protect_bpf_trampoline(void *image, unsigned int size)
{
}

void arch_unprotect_bpf_trampoline(void *image, unsigned int size)
{
}

int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *ro_image,
				void *ro_image_end, const struct btf_func_model *m,
				u32 flags, struct bpf_tramp_links *tlinks,
				void *func_addr)
{
	int ret, nregs;
	void *image, *tmp;
	u32 size = ro_image_end - ro_image;

	/* image doesn't need to be in module memory range, so we can
	 * use kvmalloc.
	 */
	image = kvmalloc(size, GFP_KERNEL);
	if (!image)
		return -ENOMEM;

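	/* generate the trampoline into the writable scratch buffer first; the
	 * result is copied into the read-only image further down
	 */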
	struct jit_ctx ctx = {
		.image = image,
		.ro_image = ro_image,
		.idx = 0,
	};

	nregs = btf_func_model_nregs(m);
	/* the first 8 registers are used for arguments */
	if (nregs > 8) {
		ret = -ENOTSUPP;
		goto out;
	}

	jit_fill_hole(image, (unsigned int)(ro_image_end - ro_image));
	ret = prepare_trampoline(&ctx, im, tlinks, func_addr, nregs, flags);

	if (ret > 0 && validate_code(&ctx) < 0) {
		ret = -EINVAL;
		goto out;
	}

	if (ret > 0)
		ret *= AARCH64_INSN_SIZE;

	tmp = bpf_arch_text_copy(ro_image, image, size);
	if (IS_ERR(tmp)) {
		ret = PTR_ERR(tmp);
		goto out;
	}

	bpf_flush_icache(ro_image, ro_image + size);
out:
	kvfree(image);
	return ret;
}

static bool is_long_jump(void *ip, void *target)
{
	long offset;

	/* NULL target means this is a NOP */
	if (!target)
		return false;

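	/* b/bl encode a 26-bit signed instruction offset, giving a direct
	 * branch range of [pc - 128MB, pc + 128MB)
	 */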
	offset = (long)target - (long)ip;
	return offset < -SZ_128M || offset >= SZ_128M;
}

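/* Generate a nop if @addr is NULL, a direct branch to @addr if it is within
 * direct branch range of @ip, or a branch to @plt otherwise.
 */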
static int gen_branch_or_nop(enum aarch64_insn_branch_type type, void *ip,
			     void *addr, void *plt, u32 *insn)
{
	void *target;

	if (!addr) {
		*insn = aarch64_insn_gen_nop();
		return 0;
	}

	if (is_long_jump(ip, addr))
		target = plt;
	else
		target = addr;

	*insn = aarch64_insn_gen_branch_imm((unsigned long)ip,
					    (unsigned long)target,
					    type);

	return *insn != AARCH64_BREAK_FAULT ? 0 : -EFAULT;
}

/* Replace the branch instruction from @ip to @old_addr in a bpf prog or a bpf
 * trampoline with the branch instruction from @ip to @new_addr. If @old_addr
 * or @new_addr is NULL, the old or new instruction is a nop.
 *
 * When @ip is the bpf prog entry, a bpf trampoline is being attached or
 * detached. Since bpf trampoline and bpf prog are allocated separately with
 * vmalloc, the address distance may exceed 128MB, the maximum branch range,
 * so long jumps need to be handled.
 *
 * When a bpf prog is constructed, a plt pointing to the empty trampoline
 * dummy_tramp is placed at the end:
 *
 *      bpf_prog:
 *              mov x9, lr
 *              nop // patchsite
 *              ...
 *              ret
 *
 *      plt:
 *              ldr x10, target
 *              br x10
 *      target:
 *              .quad dummy_tramp // plt target
 *
 * This is also the state when no trampoline is attached.
 *
 * When a short-jump bpf trampoline is attached, the patchsite is patched
 * to a bl instruction to the trampoline directly:
 *
 *      bpf_prog:
 *              mov x9, lr
 *              bl <short-jump bpf trampoline address> // patchsite
 *              ...
 *              ret
 *
 *      plt:
 *              ldr x10, target
 *              br x10
 *      target:
 *              .quad dummy_tramp // plt target
 *
 * When a long-jump bpf trampoline is attached, the plt target is filled with
 * the trampoline address and the patchsite is patched to a bl instruction to
 * the plt:
 *
 *      bpf_prog:
 *              mov x9, lr
 *              bl plt // patchsite
 *              ...
 *              ret
 *
 *      plt:
 *              ldr x10, target
 *              br x10
 *      target:
 *              .quad <long-jump bpf trampoline address> // plt target
 *
 * The dummy_tramp is used to prevent another CPU from jumping to unknown
 * locations during the patching process, making the patching process easier.
 */
int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type poke_type,
		       void *old_addr, void *new_addr)
{
	int ret;
	u32 old_insn;
	u32 new_insn;
	u32 replaced;
	struct bpf_plt *plt = NULL;
	unsigned long size = 0UL;
	unsigned long offset = ~0UL;
	enum aarch64_insn_branch_type branch_type;
	char namebuf[KSYM_NAME_LEN];
	void *image = NULL;
	u64 plt_target = 0ULL;
	bool poking_bpf_entry;

	if (!__bpf_address_lookup((unsigned long)ip, &size, &offset, namebuf))
		/* Only poking bpf text is supported. Since kernel function
		 * entry is set up by ftrace, we rely on ftrace to poke kernel
		 * functions.
		 */
		return -ENOTSUPP;
	image = ip - offset;
	/* zero offset means we're poking bpf prog entry */
	poking_bpf_entry = (offset == 0UL);

	/* bpf prog entry, find plt and the real patchsite */
	if (poking_bpf_entry) {
		/* the plt is located at the end of the bpf prog */
		plt = image + size - PLT_TARGET_OFFSET;

		/* skip to the nop instruction in bpf prog entry:
		 * bti c // if BTI enabled
		 * mov x9, x30
		 * nop
		 */
		ip = image + POKE_OFFSET * AARCH64_INSN_SIZE;
	}

	/* long jump is only possible at bpf prog entry */
	if (WARN_ON((is_long_jump(ip, new_addr) || is_long_jump(ip, old_addr)) &&
		    !poking_bpf_entry))
		return -EINVAL;

	if (poke_type == BPF_MOD_CALL)
		branch_type = AARCH64_INSN_BRANCH_LINK;
	else
		branch_type = AARCH64_INSN_BRANCH_NOLINK;

	if (gen_branch_or_nop(branch_type, ip, old_addr, plt, &old_insn) < 0)
		return -EFAULT;

	if (gen_branch_or_nop(branch_type, ip, new_addr, plt, &new_insn) < 0)
		return -EFAULT;

	if (is_long_jump(ip, new_addr))
		plt_target = (u64)new_addr;
	else if (is_long_jump(ip, old_addr))
		/* if the old target is a long jump and the new target is not,
		 * restore the plt target to dummy_tramp, so there is always a
		 * legal and harmless address stored in plt target, and we'll
		 * never jump from plt to an unknown place.
		 */
		plt_target = (u64)&dummy_tramp;

	if (plt_target) {
		/* non-zero plt_target indicates we're patching a bpf prog,
		 * which is read only.
		 */
		if (set_memory_rw(PAGE_MASK & ((uintptr_t)&plt->target), 1))
			return -EFAULT;
		WRITE_ONCE(plt->target, plt_target);
		set_memory_ro(PAGE_MASK & ((uintptr_t)&plt->target), 1);
		/* since plt target points to either the new trampoline
		 * or dummy_tramp, even if another CPU reads the old plt
		 * target value before fetching the bl instruction to plt,
		 * it will be brought back by dummy_tramp, so no barrier is
		 * required here.
		 */
	}

	/* if the old target and the new target are both long jumps, no
	 * patching is required
	 */
	if (old_insn == new_insn)
		return 0;

	mutex_lock(&text_mutex);
	if (aarch64_insn_read(ip, &replaced)) {
		ret = -EFAULT;
		goto out;
	}

	if (replaced != old_insn) {
		ret = -EFAULT;
		goto out;
	}

	/* We call aarch64_insn_patch_text_nosync() to replace the instruction
	 * atomically, so no other CPUs will fetch a half-new and half-old
	 * instruction. But there is a chance that another CPU executes the
	 * old instruction after the patching operation finishes (e.g.,
	 * pipeline not flushed, or icache not synchronized yet).
	 *
	 * 1. when a new trampoline is attached, it is not a problem for
	 *    different CPUs to jump to different trampolines temporarily.
	 *
	 * 2. when an old trampoline is freed, we should wait for all other
	 *    CPUs to exit the trampoline and make sure the trampoline is no
	 *    longer reachable. Since bpf_tramp_image_put() already uses
	 *    percpu_ref and task-based rcu to do the sync, there is no need
	 *    to call the sync version here; see bpf_tramp_image_put() for
	 *    details.
	 */
	ret = aarch64_insn_patch_text_nosync(ip, new_insn);
out:
	mutex_unlock(&text_mutex);

	return ret;
}

bool bpf_jit_supports_ptr_xchg(void)
{
	return true;
}

bool bpf_jit_supports_exceptions(void)
{
	/* We unwind through both kernel frames starting from within bpf_throw
	 * call and BPF frames. Therefore we require FP unwinder to be enabled
	 * to walk kernel frames and reach BPF frames in the stack trace.
	 * The ARM64 kernel is always compiled with CONFIG_FRAME_POINTER=y.
	 */
	return true;
}

void bpf_jit_free(struct bpf_prog *prog)
{
	if (prog->jited) {
		struct arm64_jit_data *jit_data = prog->aux->jit_data;
		struct bpf_binary_header *hdr;

		/*
		 * If we fail the final pass of JIT (from jit_subprogs),
		 * the program may not be finalized yet. Call finalize here
		 * before freeing it.
		 */
		if (jit_data) {
			bpf_arch_text_copy(&jit_data->ro_header->size, &jit_data->header->size,
					   sizeof(jit_data->header->size));
			kfree(jit_data);
		}
		hdr = bpf_jit_binary_pack_hdr(prog);
		bpf_jit_binary_pack_free(hdr, NULL);
		WARN_ON_ONCE(!bpf_prog_kallsyms_verify_off(prog));
	}

	bpf_prog_unlock_free(prog);
}