// SPDX-License-Identifier: GPL-2.0
/*
 * BPF JIT compiler for RV32G
 *
 * Copyright (c) 2020 Luke Nelson <luke.r.nels@gmail.com>
 * Copyright (c) 2020 Xi Wang <xi.wang@gmail.com>
 *
 * The code is based on the BPF JIT compiler for RV64G by Björn Töpel and
 * the BPF JIT compiler for 32-bit ARM by Shubham Bansal and Mircea Gherzan.
 */

#include <linux/bpf.h>
#include <linux/filter.h>
#include "bpf_jit.h"

/*
 * Stack layout during BPF program execution:
 *
 *                     high
 *     RV32 fp =>  +----------+
 *                 | saved ra |
 *                 | saved fp | RV32 callee-saved registers
 *                 |   ...    |
 *                 +----------+ <= (fp - 4 * NR_SAVED_REGISTERS)
 *                 |  hi(R6)  |
 *                 |  lo(R6)  |
 *                 |  hi(R7)  | JIT scratch space for BPF registers
 *                 |  lo(R7)  |
 *                 |   ...    |
 *  BPF_REG_FP =>  +----------+ <= (fp - 4 * NR_SAVED_REGISTERS
 *                 |          |        - 4 * BPF_JIT_SCRATCH_REGS)
 *                 |          |
 *                 |   ...    | BPF program stack
 *                 |          |
 *     RV32 sp =>  +----------+
 *                 |          |
 *                 |   ...    | Function call stack
 *                 |          |
 *                 +----------+
 *                     low
 */

enum {
	/* Stack layout - these are offsets from top of JIT scratch space. */
	BPF_R6_HI,
	BPF_R6_LO,
	BPF_R7_HI,
	BPF_R7_LO,
	BPF_R8_HI,
	BPF_R8_LO,
	BPF_R9_HI,
	BPF_R9_LO,
	BPF_AX_HI,
	BPF_AX_LO,
	/* Stack space for BPF_REG_6 through BPF_REG_9 and BPF_REG_AX. */
	BPF_JIT_SCRATCH_REGS,
};

/* Number of callee-saved registers stored to stack: ra, fp, s1--s7. */
#define NR_SAVED_REGISTERS	9

/* Offset from fp for BPF registers stored on stack. */
#define STACK_OFFSET(k)	(-4 - (4 * NR_SAVED_REGISTERS) - (4 * (k)))
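
/*
 * For example, with NR_SAVED_REGISTERS = 9, STACK_OFFSET(BPF_R6_HI) is
 * -4 - 36 - 0 = -40, so hi(R6) lives at fp - 40, while lo(AX) at
 * STACK_OFFSET(BPF_AX_LO) = -4 - 36 - 36 = -76 is the bottom of the
 * scratch space.
 */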

#define TMP_REG_1	(MAX_BPF_JIT_REG + 0)
#define TMP_REG_2	(MAX_BPF_JIT_REG + 1)

#define RV_REG_TCC		RV_REG_T6
#define RV_REG_TCC_SAVED	RV_REG_S7

static const s8 bpf2rv32[][2] = {
	/* Return value from in-kernel function, and exit value from eBPF. */
	[BPF_REG_0] = {RV_REG_S2, RV_REG_S1},
	/* Arguments from eBPF program to in-kernel function. */
	[BPF_REG_1] = {RV_REG_A1, RV_REG_A0},
	[BPF_REG_2] = {RV_REG_A3, RV_REG_A2},
	[BPF_REG_3] = {RV_REG_A5, RV_REG_A4},
	[BPF_REG_4] = {RV_REG_A7, RV_REG_A6},
	[BPF_REG_5] = {RV_REG_S4, RV_REG_S3},
	/*
	 * Callee-saved registers that in-kernel function will preserve.
	 * Stored on the stack.
	 */
	[BPF_REG_6] = {STACK_OFFSET(BPF_R6_HI), STACK_OFFSET(BPF_R6_LO)},
	[BPF_REG_7] = {STACK_OFFSET(BPF_R7_HI), STACK_OFFSET(BPF_R7_LO)},
	[BPF_REG_8] = {STACK_OFFSET(BPF_R8_HI), STACK_OFFSET(BPF_R8_LO)},
	[BPF_REG_9] = {STACK_OFFSET(BPF_R9_HI), STACK_OFFSET(BPF_R9_LO)},
	/* Read-only frame pointer to access BPF stack. */
	[BPF_REG_FP] = {RV_REG_S6, RV_REG_S5},
	/* Temporary register for blinding constants. Stored on the stack. */
	[BPF_REG_AX] = {STACK_OFFSET(BPF_AX_HI), STACK_OFFSET(BPF_AX_LO)},
	/*
	 * Temporary registers used by the JIT to operate on registers stored
	 * on the stack. Save t0 and t1 to be used as temporaries in generated
	 * code.
	 */
	[TMP_REG_1] = {RV_REG_T3, RV_REG_T2},
	[TMP_REG_2] = {RV_REG_T5, RV_REG_T4},
};

static s8 hi(const s8 *r)
{
	return r[0];
}

static s8 lo(const s8 *r)
{
	return r[1];
}

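/*
 * Load a 32-bit immediate as an lui/addi pair. addi sign-extends its
 * 12-bit operand, so the upper part is rounded by (1 << 11) to
 * compensate: e.g. for imm = 0x12345fff, lower = 0xfff (-1 once
 * sign-extended) and upper = 0x12346, giving 0x12346000 - 1.
 */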
static void emit_imm(const s8 rd, s32 imm, struct rv_jit_context *ctx)
{
	u32 upper = (imm + (1 << 11)) >> 12;
	u32 lower = imm & 0xfff;

	if (upper) {
		emit(rv_lui(rd, upper), ctx);
		emit(rv_addi(rd, rd, lower), ctx);
	} else {
		emit(rv_addi(rd, RV_REG_ZERO, lower), ctx);
	}
}

static void emit_imm32(const s8 *rd, s32 imm, struct rv_jit_context *ctx)
{
	/* Emit immediate into lower bits. */
	emit_imm(lo(rd), imm, ctx);

	/* Sign-extend into upper bits. */
	if (imm >= 0)
		emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
	else
		emit(rv_addi(hi(rd), RV_REG_ZERO, -1), ctx);
}

static void emit_imm64(const s8 *rd, s32 imm_hi, s32 imm_lo,
		       struct rv_jit_context *ctx)
{
	emit_imm(lo(rd), imm_lo, ctx);
	emit_imm(hi(rd), imm_hi, ctx);
}

static void __build_epilogue(bool is_tail_call, struct rv_jit_context *ctx)
{
	int stack_adjust = ctx->stack_size;
	const s8 *r0 = bpf2rv32[BPF_REG_0];

	/* Set return value if not tail call. */
	if (!is_tail_call) {
		emit(rv_addi(RV_REG_A0, lo(r0), 0), ctx);
		emit(rv_addi(RV_REG_A1, hi(r0), 0), ctx);
	}

	/* Restore callee-saved registers. */
	emit(rv_lw(RV_REG_RA, stack_adjust - 4, RV_REG_SP), ctx);
	emit(rv_lw(RV_REG_FP, stack_adjust - 8, RV_REG_SP), ctx);
	emit(rv_lw(RV_REG_S1, stack_adjust - 12, RV_REG_SP), ctx);
	emit(rv_lw(RV_REG_S2, stack_adjust - 16, RV_REG_SP), ctx);
	emit(rv_lw(RV_REG_S3, stack_adjust - 20, RV_REG_SP), ctx);
	emit(rv_lw(RV_REG_S4, stack_adjust - 24, RV_REG_SP), ctx);
	emit(rv_lw(RV_REG_S5, stack_adjust - 28, RV_REG_SP), ctx);
	emit(rv_lw(RV_REG_S6, stack_adjust - 32, RV_REG_SP), ctx);
	emit(rv_lw(RV_REG_S7, stack_adjust - 36, RV_REG_SP), ctx);

	emit(rv_addi(RV_REG_SP, RV_REG_SP, stack_adjust), ctx);

	if (is_tail_call) {
		/*
		 * goto *(t0 + 4);
		 * Skips first instruction of prologue which initializes tail
		 * call counter. Assumes t0 contains address of target program,
		 * see emit_bpf_tail_call.
		 */
		emit(rv_jalr(RV_REG_ZERO, RV_REG_T0, 4), ctx);
	} else {
		emit(rv_jalr(RV_REG_ZERO, RV_REG_RA, 0), ctx);
	}
}

static bool is_stacked(s8 reg)
{
	return reg < 0;
}

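/*
 * BPF registers mapped to negative "register" numbers in bpf2rv32 live
 * on the stack at that offset from fp (see STACK_OFFSET); is_stacked()
 * keys off the sign. The get helpers below load such registers into a
 * temporary register pair before use, and the put helpers write the
 * result back; for registers that map to real RV32 registers they are
 * no-ops.
 */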
static const s8 *bpf_get_reg64(const s8 *reg, const s8 *tmp,
			       struct rv_jit_context *ctx)
{
	if (is_stacked(hi(reg))) {
		emit(rv_lw(hi(tmp), hi(reg), RV_REG_FP), ctx);
		emit(rv_lw(lo(tmp), lo(reg), RV_REG_FP), ctx);
		reg = tmp;
	}
	return reg;
}

static void bpf_put_reg64(const s8 *reg, const s8 *src,
			  struct rv_jit_context *ctx)
{
	if (is_stacked(hi(reg))) {
		emit(rv_sw(RV_REG_FP, hi(reg), hi(src)), ctx);
		emit(rv_sw(RV_REG_FP, lo(reg), lo(src)), ctx);
	}
}

static const s8 *bpf_get_reg32(const s8 *reg, const s8 *tmp,
			       struct rv_jit_context *ctx)
{
	if (is_stacked(lo(reg))) {
		emit(rv_lw(lo(tmp), lo(reg), RV_REG_FP), ctx);
		reg = tmp;
	}
	return reg;
}

static void bpf_put_reg32(const s8 *reg, const s8 *src,
			  struct rv_jit_context *ctx)
{
	if (is_stacked(lo(reg))) {
		emit(rv_sw(RV_REG_FP, lo(reg), lo(src)), ctx);
		if (!ctx->prog->aux->verifier_zext)
			emit(rv_sw(RV_REG_FP, hi(reg), RV_REG_ZERO), ctx);
	} else if (!ctx->prog->aux->verifier_zext) {
		emit(rv_addi(hi(reg), RV_REG_ZERO, 0), ctx);
	}
}

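/*
 * Emit a near jump as a single jal when the offset fits in 21 bits;
 * otherwise fall back to an auipc/jalr pair, which can reach any
 * pc-relative offset. The same (1 << 11) rounding as in emit_imm
 * compensates for jalr sign-extending its 12-bit displacement.
 */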
static void emit_jump_and_link(u8 rd, s32 rvoff, bool force_jalr,
			       struct rv_jit_context *ctx)
{
	s32 upper, lower;

	if (rvoff && is_21b_int(rvoff) && !force_jalr) {
		emit(rv_jal(rd, rvoff >> 1), ctx);
		return;
	}

	upper = (rvoff + (1 << 11)) >> 12;
	lower = rvoff & 0xfff;
	emit(rv_auipc(RV_REG_T1, upper), ctx);
	emit(rv_jalr(rd, RV_REG_T1, lower), ctx);
}

static void emit_alu_i64(const s8 *dst, s32 imm,
			 struct rv_jit_context *ctx, const u8 op)
{
	const s8 *tmp1 = bpf2rv32[TMP_REG_1];
	const s8 *rd = bpf_get_reg64(dst, tmp1, ctx);

	switch (op) {
	case BPF_MOV:
		emit_imm32(rd, imm, ctx);
		break;
	case BPF_AND:
		if (is_12b_int(imm)) {
			emit(rv_andi(lo(rd), lo(rd), imm), ctx);
		} else {
			emit_imm(RV_REG_T0, imm, ctx);
			emit(rv_and(lo(rd), lo(rd), RV_REG_T0), ctx);
		}
		if (imm >= 0)
			emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
		break;
	case BPF_OR:
		if (is_12b_int(imm)) {
			emit(rv_ori(lo(rd), lo(rd), imm), ctx);
		} else {
			emit_imm(RV_REG_T0, imm, ctx);
			emit(rv_or(lo(rd), lo(rd), RV_REG_T0), ctx);
		}
		if (imm < 0)
			emit(rv_ori(hi(rd), RV_REG_ZERO, -1), ctx);
		break;
	case BPF_XOR:
		if (is_12b_int(imm)) {
			emit(rv_xori(lo(rd), lo(rd), imm), ctx);
		} else {
			emit_imm(RV_REG_T0, imm, ctx);
			emit(rv_xor(lo(rd), lo(rd), RV_REG_T0), ctx);
		}
		if (imm < 0)
			emit(rv_xori(hi(rd), hi(rd), -1), ctx);
		break;
	case BPF_LSH:
		if (imm >= 32) {
			emit(rv_slli(hi(rd), lo(rd), imm - 32), ctx);
			emit(rv_addi(lo(rd), RV_REG_ZERO, 0), ctx);
		} else if (imm == 0) {
			/* Do nothing. */
		} else {
			emit(rv_srli(RV_REG_T0, lo(rd), 32 - imm), ctx);
			emit(rv_slli(hi(rd), hi(rd), imm), ctx);
			emit(rv_or(hi(rd), RV_REG_T0, hi(rd)), ctx);
			emit(rv_slli(lo(rd), lo(rd), imm), ctx);
		}
		break;
	case BPF_RSH:
		if (imm >= 32) {
			emit(rv_srli(lo(rd), hi(rd), imm - 32), ctx);
			emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
		} else if (imm == 0) {
			/* Do nothing. */
		} else {
			emit(rv_slli(RV_REG_T0, hi(rd), 32 - imm), ctx);
			emit(rv_srli(lo(rd), lo(rd), imm), ctx);
			emit(rv_or(lo(rd), RV_REG_T0, lo(rd)), ctx);
			emit(rv_srli(hi(rd), hi(rd), imm), ctx);
		}
		break;
	case BPF_ARSH:
		if (imm >= 32) {
			emit(rv_srai(lo(rd), hi(rd), imm - 32), ctx);
			emit(rv_srai(hi(rd), hi(rd), 31), ctx);
		} else if (imm == 0) {
			/* Do nothing. */
		} else {
			emit(rv_slli(RV_REG_T0, hi(rd), 32 - imm), ctx);
			emit(rv_srli(lo(rd), lo(rd), imm), ctx);
			emit(rv_or(lo(rd), RV_REG_T0, lo(rd)), ctx);
			emit(rv_srai(hi(rd), hi(rd), imm), ctx);
		}
		break;
	}

	bpf_put_reg64(dst, rd, ctx);
}

static void emit_alu_i32(const s8 *dst, s32 imm,
			 struct rv_jit_context *ctx, const u8 op)
{
	const s8 *tmp1 = bpf2rv32[TMP_REG_1];
	const s8 *rd = bpf_get_reg32(dst, tmp1, ctx);

	switch (op) {
	case BPF_MOV:
		emit_imm(lo(rd), imm, ctx);
		break;
	case BPF_ADD:
		if (is_12b_int(imm)) {
			emit(rv_addi(lo(rd), lo(rd), imm), ctx);
		} else {
			emit_imm(RV_REG_T0, imm, ctx);
			emit(rv_add(lo(rd), lo(rd), RV_REG_T0), ctx);
		}
		break;
	case BPF_SUB:
		if (is_12b_int(-imm)) {
			emit(rv_addi(lo(rd), lo(rd), -imm), ctx);
		} else {
			emit_imm(RV_REG_T0, imm, ctx);
			emit(rv_sub(lo(rd), lo(rd), RV_REG_T0), ctx);
		}
		break;
	case BPF_AND:
		if (is_12b_int(imm)) {
			emit(rv_andi(lo(rd), lo(rd), imm), ctx);
		} else {
			emit_imm(RV_REG_T0, imm, ctx);
			emit(rv_and(lo(rd), lo(rd), RV_REG_T0), ctx);
		}
		break;
	case BPF_OR:
		if (is_12b_int(imm)) {
			emit(rv_ori(lo(rd), lo(rd), imm), ctx);
		} else {
			emit_imm(RV_REG_T0, imm, ctx);
			emit(rv_or(lo(rd), lo(rd), RV_REG_T0), ctx);
		}
		break;
	case BPF_XOR:
		if (is_12b_int(imm)) {
			emit(rv_xori(lo(rd), lo(rd), imm), ctx);
		} else {
			emit_imm(RV_REG_T0, imm, ctx);
			emit(rv_xor(lo(rd), lo(rd), RV_REG_T0), ctx);
		}
		break;
	case BPF_LSH:
		if (is_12b_int(imm)) {
			emit(rv_slli(lo(rd), lo(rd), imm), ctx);
		} else {
			emit_imm(RV_REG_T0, imm, ctx);
			emit(rv_sll(lo(rd), lo(rd), RV_REG_T0), ctx);
		}
		break;
	case BPF_RSH:
		if (is_12b_int(imm)) {
			emit(rv_srli(lo(rd), lo(rd), imm), ctx);
		} else {
			emit_imm(RV_REG_T0, imm, ctx);
			emit(rv_srl(lo(rd), lo(rd), RV_REG_T0), ctx);
		}
		break;
	case BPF_ARSH:
		if (is_12b_int(imm)) {
			emit(rv_srai(lo(rd), lo(rd), imm), ctx);
		} else {
			emit_imm(RV_REG_T0, imm, ctx);
			emit(rv_sra(lo(rd), lo(rd), RV_REG_T0), ctx);
		}
		break;
	}

	bpf_put_reg32(dst, rd, ctx);
}

static void emit_alu_r64(const s8 *dst, const s8 *src,
			 struct rv_jit_context *ctx, const u8 op)
{
	const s8 *tmp1 = bpf2rv32[TMP_REG_1];
	const s8 *tmp2 = bpf2rv32[TMP_REG_2];
	const s8 *rd = bpf_get_reg64(dst, tmp1, ctx);
	const s8 *rs = bpf_get_reg64(src, tmp2, ctx);

	switch (op) {
	case BPF_MOV:
		emit(rv_addi(lo(rd), lo(rs), 0), ctx);
		emit(rv_addi(hi(rd), hi(rs), 0), ctx);
		break;
	case BPF_ADD:
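		/*
		 * For rd == rs, rd + rd is rd << 1: the sltu carry test
		 * below would compare lo(rd) with itself after the add and
		 * always yield 0, so double via a 64-bit shift instead,
		 * carrying bit 31 of the low word into the high word.
		 * Otherwise add the low words, derive the carry from the
		 * unsigned overflow (result < addend), and fold it into
		 * the high word.
		 */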
		if (rd == rs) {
			emit(rv_srli(RV_REG_T0, lo(rd), 31), ctx);
			emit(rv_slli(hi(rd), hi(rd), 1), ctx);
			emit(rv_or(hi(rd), RV_REG_T0, hi(rd)), ctx);
			emit(rv_slli(lo(rd), lo(rd), 1), ctx);
		} else {
			emit(rv_add(lo(rd), lo(rd), lo(rs)), ctx);
			emit(rv_sltu(RV_REG_T0, lo(rd), lo(rs)), ctx);
			emit(rv_add(hi(rd), hi(rd), hi(rs)), ctx);
			emit(rv_add(hi(rd), hi(rd), RV_REG_T0), ctx);
		}
		break;
	case BPF_SUB:
		emit(rv_sub(RV_REG_T1, hi(rd), hi(rs)), ctx);
		emit(rv_sltu(RV_REG_T0, lo(rd), lo(rs)), ctx);
		emit(rv_sub(hi(rd), RV_REG_T1, RV_REG_T0), ctx);
		emit(rv_sub(lo(rd), lo(rd), lo(rs)), ctx);
		break;
	case BPF_AND:
		emit(rv_and(lo(rd), lo(rd), lo(rs)), ctx);
		emit(rv_and(hi(rd), hi(rd), hi(rs)), ctx);
		break;
	case BPF_OR:
		emit(rv_or(lo(rd), lo(rd), lo(rs)), ctx);
		emit(rv_or(hi(rd), hi(rd), hi(rs)), ctx);
		break;
	case BPF_XOR:
		emit(rv_xor(lo(rd), lo(rd), lo(rs)), ctx);
		emit(rv_xor(hi(rd), hi(rd), hi(rs)), ctx);
		break;
	case BPF_MUL:
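		/*
		 * Schoolbook 64x64->64 multiply: the low word is
		 * lo(rd) * lo(rs), and the high word accumulates
		 * hi(rs) * lo(rd) + hi(rd) * lo(rs) + mulhu(lo(rd), lo(rs)).
		 */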
		emit(rv_mul(RV_REG_T0, hi(rs), lo(rd)), ctx);
		emit(rv_mul(hi(rd), hi(rd), lo(rs)), ctx);
		emit(rv_mulhu(RV_REG_T1, lo(rd), lo(rs)), ctx);
		emit(rv_add(hi(rd), hi(rd), RV_REG_T0), ctx);
		emit(rv_mul(lo(rd), lo(rd), lo(rs)), ctx);
		emit(rv_add(hi(rd), hi(rd), RV_REG_T1), ctx);
		break;
	case BPF_LSH:
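		/*
		 * Variable 64-bit shift, roughly:
		 *
		 *   if (lo(rs) - 32 >= 0) {          // shift >= 32
		 *           hi = lo << (lo(rs) - 32);
		 *           lo = 0;
		 *   } else {                         // shift < 32
		 *           hi = (hi << lo(rs)) |
		 *                ((lo >> 1) >> (31 - lo(rs)));
		 *           lo <<= lo(rs);
		 *   }
		 *
		 * The spill-over is computed as two shifts so that a shift
		 * amount of zero never produces an out-of-range shift by
		 * 32. BPF_RSH and BPF_ARSH below mirror this structure.
		 */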
		emit(rv_addi(RV_REG_T0, lo(rs), -32), ctx);
		emit(rv_blt(RV_REG_T0, RV_REG_ZERO, 8), ctx);
		emit(rv_sll(hi(rd), lo(rd), RV_REG_T0), ctx);
		emit(rv_addi(lo(rd), RV_REG_ZERO, 0), ctx);
		emit(rv_jal(RV_REG_ZERO, 16), ctx);
		emit(rv_addi(RV_REG_T1, RV_REG_ZERO, 31), ctx);
		emit(rv_srli(RV_REG_T0, lo(rd), 1), ctx);
		emit(rv_sub(RV_REG_T1, RV_REG_T1, lo(rs)), ctx);
		emit(rv_srl(RV_REG_T0, RV_REG_T0, RV_REG_T1), ctx);
		emit(rv_sll(hi(rd), hi(rd), lo(rs)), ctx);
		emit(rv_or(hi(rd), RV_REG_T0, hi(rd)), ctx);
		emit(rv_sll(lo(rd), lo(rd), lo(rs)), ctx);
		break;
	case BPF_RSH:
		emit(rv_addi(RV_REG_T0, lo(rs), -32), ctx);
		emit(rv_blt(RV_REG_T0, RV_REG_ZERO, 8), ctx);
		emit(rv_srl(lo(rd), hi(rd), RV_REG_T0), ctx);
		emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
		emit(rv_jal(RV_REG_ZERO, 16), ctx);
		emit(rv_addi(RV_REG_T1, RV_REG_ZERO, 31), ctx);
		emit(rv_slli(RV_REG_T0, hi(rd), 1), ctx);
		emit(rv_sub(RV_REG_T1, RV_REG_T1, lo(rs)), ctx);
		emit(rv_sll(RV_REG_T0, RV_REG_T0, RV_REG_T1), ctx);
		emit(rv_srl(lo(rd), lo(rd), lo(rs)), ctx);
		emit(rv_or(lo(rd), RV_REG_T0, lo(rd)), ctx);
		emit(rv_srl(hi(rd), hi(rd), lo(rs)), ctx);
		break;
	case BPF_ARSH:
		emit(rv_addi(RV_REG_T0, lo(rs), -32), ctx);
		emit(rv_blt(RV_REG_T0, RV_REG_ZERO, 8), ctx);
		emit(rv_sra(lo(rd), hi(rd), RV_REG_T0), ctx);
		emit(rv_srai(hi(rd), hi(rd), 31), ctx);
		emit(rv_jal(RV_REG_ZERO, 16), ctx);
		emit(rv_addi(RV_REG_T1, RV_REG_ZERO, 31), ctx);
		emit(rv_slli(RV_REG_T0, hi(rd), 1), ctx);
		emit(rv_sub(RV_REG_T1, RV_REG_T1, lo(rs)), ctx);
		emit(rv_sll(RV_REG_T0, RV_REG_T0, RV_REG_T1), ctx);
		emit(rv_srl(lo(rd), lo(rd), lo(rs)), ctx);
		emit(rv_or(lo(rd), RV_REG_T0, lo(rd)), ctx);
		emit(rv_sra(hi(rd), hi(rd), lo(rs)), ctx);
		break;
	case BPF_NEG:
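		/*
		 * Two's-complement negate with borrow: after negating the
		 * low word, a borrow into the high word is needed exactly
		 * when the (new) low word is nonzero.
		 */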
		emit(rv_sub(lo(rd), RV_REG_ZERO, lo(rd)), ctx);
		emit(rv_sltu(RV_REG_T0, RV_REG_ZERO, lo(rd)), ctx);
		emit(rv_sub(hi(rd), RV_REG_ZERO, hi(rd)), ctx);
		emit(rv_sub(hi(rd), hi(rd), RV_REG_T0), ctx);
		break;
	}

	bpf_put_reg64(dst, rd, ctx);
}

static void emit_alu_r32(const s8 *dst, const s8 *src,
			 struct rv_jit_context *ctx, const u8 op)
{
	const s8 *tmp1 = bpf2rv32[TMP_REG_1];
	const s8 *tmp2 = bpf2rv32[TMP_REG_2];
	const s8 *rd = bpf_get_reg32(dst, tmp1, ctx);
	const s8 *rs = bpf_get_reg32(src, tmp2, ctx);

	switch (op) {
	case BPF_MOV:
		emit(rv_addi(lo(rd), lo(rs), 0), ctx);
		break;
	case BPF_ADD:
		emit(rv_add(lo(rd), lo(rd), lo(rs)), ctx);
		break;
	case BPF_SUB:
		emit(rv_sub(lo(rd), lo(rd), lo(rs)), ctx);
		break;
	case BPF_AND:
		emit(rv_and(lo(rd), lo(rd), lo(rs)), ctx);
		break;
	case BPF_OR:
		emit(rv_or(lo(rd), lo(rd), lo(rs)), ctx);
		break;
	case BPF_XOR:
		emit(rv_xor(lo(rd), lo(rd), lo(rs)), ctx);
		break;
	case BPF_MUL:
		emit(rv_mul(lo(rd), lo(rd), lo(rs)), ctx);
		break;
	case BPF_DIV:
		emit(rv_divu(lo(rd), lo(rd), lo(rs)), ctx);
		break;
	case BPF_MOD:
		emit(rv_remu(lo(rd), lo(rd), lo(rs)), ctx);
		break;
	case BPF_LSH:
		emit(rv_sll(lo(rd), lo(rd), lo(rs)), ctx);
		break;
	case BPF_RSH:
		emit(rv_srl(lo(rd), lo(rd), lo(rs)), ctx);
		break;
	case BPF_ARSH:
		emit(rv_sra(lo(rd), lo(rd), lo(rs)), ctx);
		break;
	case BPF_NEG:
		emit(rv_sub(lo(rd), RV_REG_ZERO, lo(rd)), ctx);
		break;
	}

	bpf_put_reg32(dst, rd, ctx);
}

static int emit_branch_r64(const s8 *src1, const s8 *src2, s32 rvoff,
			   struct rv_jit_context *ctx, const u8 op)
{
	int e, s = ctx->ninsns;
	const s8 *tmp1 = bpf2rv32[TMP_REG_1];
	const s8 *tmp2 = bpf2rv32[TMP_REG_2];

	const s8 *rs1 = bpf_get_reg64(src1, tmp1, ctx);
	const s8 *rs2 = bpf_get_reg64(src2, tmp2, ctx);

	/*
	 * NO_JUMP skips over the rest of the instructions and the
	 * emit_jump_and_link, meaning the BPF branch is not taken.
	 * JUMP skips directly to the emit_jump_and_link, meaning
	 * the BPF branch is taken.
	 *
	 * The fallthrough case results in the BPF branch being taken.
	 */
#define NO_JUMP(idx) (6 + (2 * (idx)))
#define JUMP(idx) (2 + (2 * (idx)))
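
/*
 * Offsets are in halfwords, relative to the branch being emitted, and
 * emit_jump_and_link is forced to its two-instruction auipc/jalr form
 * below. E.g. in BPF_JEQ the first bne uses NO_JUMP(1) = 8 halfwords:
 * it skips itself, the second bne, and those two jump instructions.
 */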

	switch (op) {
	case BPF_JEQ:
		emit(rv_bne(hi(rs1), hi(rs2), NO_JUMP(1)), ctx);
		emit(rv_bne(lo(rs1), lo(rs2), NO_JUMP(0)), ctx);
		break;
	case BPF_JGT:
		emit(rv_bgtu(hi(rs1), hi(rs2), JUMP(2)), ctx);
		emit(rv_bltu(hi(rs1), hi(rs2), NO_JUMP(1)), ctx);
		emit(rv_bleu(lo(rs1), lo(rs2), NO_JUMP(0)), ctx);
		break;
	case BPF_JLT:
		emit(rv_bltu(hi(rs1), hi(rs2), JUMP(2)), ctx);
		emit(rv_bgtu(hi(rs1), hi(rs2), NO_JUMP(1)), ctx);
		emit(rv_bgeu(lo(rs1), lo(rs2), NO_JUMP(0)), ctx);
		break;
	case BPF_JGE:
		emit(rv_bgtu(hi(rs1), hi(rs2), JUMP(2)), ctx);
		emit(rv_bltu(hi(rs1), hi(rs2), NO_JUMP(1)), ctx);
		emit(rv_bltu(lo(rs1), lo(rs2), NO_JUMP(0)), ctx);
		break;
	case BPF_JLE:
		emit(rv_bltu(hi(rs1), hi(rs2), JUMP(2)), ctx);
		emit(rv_bgtu(hi(rs1), hi(rs2), NO_JUMP(1)), ctx);
		emit(rv_bgtu(lo(rs1), lo(rs2), NO_JUMP(0)), ctx);
		break;
	case BPF_JNE:
		emit(rv_bne(hi(rs1), hi(rs2), JUMP(1)), ctx);
		emit(rv_beq(lo(rs1), lo(rs2), NO_JUMP(0)), ctx);
		break;
	case BPF_JSGT:
		emit(rv_bgt(hi(rs1), hi(rs2), JUMP(2)), ctx);
		emit(rv_blt(hi(rs1), hi(rs2), NO_JUMP(1)), ctx);
		emit(rv_bleu(lo(rs1), lo(rs2), NO_JUMP(0)), ctx);
		break;
	case BPF_JSLT:
		emit(rv_blt(hi(rs1), hi(rs2), JUMP(2)), ctx);
		emit(rv_bgt(hi(rs1), hi(rs2), NO_JUMP(1)), ctx);
		emit(rv_bgeu(lo(rs1), lo(rs2), NO_JUMP(0)), ctx);
		break;
	case BPF_JSGE:
		emit(rv_bgt(hi(rs1), hi(rs2), JUMP(2)), ctx);
		emit(rv_blt(hi(rs1), hi(rs2), NO_JUMP(1)), ctx);
		emit(rv_bltu(lo(rs1), lo(rs2), NO_JUMP(0)), ctx);
		break;
	case BPF_JSLE:
		emit(rv_blt(hi(rs1), hi(rs2), JUMP(2)), ctx);
		emit(rv_bgt(hi(rs1), hi(rs2), NO_JUMP(1)), ctx);
		emit(rv_bgtu(lo(rs1), lo(rs2), NO_JUMP(0)), ctx);
		break;
	case BPF_JSET:
		emit(rv_and(RV_REG_T0, hi(rs1), hi(rs2)), ctx);
		emit(rv_bne(RV_REG_T0, RV_REG_ZERO, JUMP(2)), ctx);
		emit(rv_and(RV_REG_T0, lo(rs1), lo(rs2)), ctx);
		emit(rv_beq(RV_REG_T0, RV_REG_ZERO, NO_JUMP(0)), ctx);
		break;
	}

#undef NO_JUMP
#undef JUMP

	e = ctx->ninsns;
	/* Adjust for extra insns. */
	rvoff -= ninsns_rvoff(e - s);
	emit_jump_and_link(RV_REG_ZERO, rvoff, true, ctx);
	return 0;
}

static int emit_bcc(u8 op, u8 rd, u8 rs, int rvoff, struct rv_jit_context *ctx)
{
	int e, s = ctx->ninsns;
	bool far = false;
	int off;

	if (op == BPF_JSET) {
		/*
		 * BPF_JSET is a special case: it has no inverse so we always
		 * treat it as a far branch.
		 */
		far = true;
	} else if (!is_13b_int(rvoff)) {
		op = invert_bpf_cond(op);
		far = true;
	}

	/*
	 * For a far branch, the condition is negated and we jump over the
	 * branch itself, and the two instructions from emit_jump_and_link.
	 * For a near branch, just use rvoff.
	 */
	off = far ? 6 : (rvoff >> 1);

	switch (op) {
	case BPF_JEQ:
		emit(rv_beq(rd, rs, off), ctx);
		break;
	case BPF_JGT:
		emit(rv_bgtu(rd, rs, off), ctx);
		break;
	case BPF_JLT:
		emit(rv_bltu(rd, rs, off), ctx);
		break;
	case BPF_JGE:
		emit(rv_bgeu(rd, rs, off), ctx);
		break;
	case BPF_JLE:
		emit(rv_bleu(rd, rs, off), ctx);
		break;
	case BPF_JNE:
		emit(rv_bne(rd, rs, off), ctx);
		break;
	case BPF_JSGT:
		emit(rv_bgt(rd, rs, off), ctx);
		break;
	case BPF_JSLT:
		emit(rv_blt(rd, rs, off), ctx);
		break;
	case BPF_JSGE:
		emit(rv_bge(rd, rs, off), ctx);
		break;
	case BPF_JSLE:
		emit(rv_ble(rd, rs, off), ctx);
		break;
	case BPF_JSET:
		emit(rv_and(RV_REG_T0, rd, rs), ctx);
		emit(rv_beq(RV_REG_T0, RV_REG_ZERO, off), ctx);
		break;
	}

	if (far) {
		e = ctx->ninsns;
		/* Adjust for extra insns. */
		rvoff -= ninsns_rvoff(e - s);
		emit_jump_and_link(RV_REG_ZERO, rvoff, true, ctx);
	}
	return 0;
}

static int emit_branch_r32(const s8 *src1, const s8 *src2, s32 rvoff,
			   struct rv_jit_context *ctx, const u8 op)
{
	int e, s = ctx->ninsns;
	const s8 *tmp1 = bpf2rv32[TMP_REG_1];
	const s8 *tmp2 = bpf2rv32[TMP_REG_2];

	const s8 *rs1 = bpf_get_reg32(src1, tmp1, ctx);
	const s8 *rs2 = bpf_get_reg32(src2, tmp2, ctx);

	e = ctx->ninsns;
	/* Adjust for extra insns. */
	rvoff -= ninsns_rvoff(e - s);

	if (emit_bcc(op, lo(rs1), lo(rs2), rvoff, ctx))
		return -1;

	return 0;
}

static void emit_call(bool fixed, u64 addr, struct rv_jit_context *ctx)
{
	const s8 *r0 = bpf2rv32[BPF_REG_0];
	const s8 *r5 = bpf2rv32[BPF_REG_5];
	u32 upper = ((u32)addr + (1 << 11)) >> 12;
	u32 lower = addr & 0xfff;

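	/*
	 * With the ILP32 calling convention each u64 argument occupies an
	 * aligned register pair, so the four pairs in a0-a7 cover R1-R4
	 * and the fifth BPF argument has to be spilled.
	 */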
	/* R1-R4 already in correct registers---need to push R5 to stack. */
	emit(rv_addi(RV_REG_SP, RV_REG_SP, -16), ctx);
	emit(rv_sw(RV_REG_SP, 0, lo(r5)), ctx);
	emit(rv_sw(RV_REG_SP, 4, hi(r5)), ctx);

	/* Backup TCC. */
	emit(rv_addi(RV_REG_TCC_SAVED, RV_REG_TCC, 0), ctx);

	/*
	 * Use lui/jalr pair to jump to absolute address. Don't use emit_imm as
	 * the number of emitted instructions should not depend on the value of
	 * addr.
	 */
	emit(rv_lui(RV_REG_T1, upper), ctx);
	emit(rv_jalr(RV_REG_RA, RV_REG_T1, lower), ctx);

	/* Restore TCC. */
	emit(rv_addi(RV_REG_TCC, RV_REG_TCC_SAVED, 0), ctx);

	/* Set return value and restore stack. */
	emit(rv_addi(lo(r0), RV_REG_A0, 0), ctx);
	emit(rv_addi(hi(r0), RV_REG_A1, 0), ctx);
	emit(rv_addi(RV_REG_SP, RV_REG_SP, 16), ctx);
}

static int emit_bpf_tail_call(int insn, struct rv_jit_context *ctx)
{
	/*
	 * R1 -> &ctx
	 * R2 -> &array
	 * R3 -> index
	 */
	int tc_ninsn, off, start_insn = ctx->ninsns;
	const s8 *arr_reg = bpf2rv32[BPF_REG_2];
	const s8 *idx_reg = bpf2rv32[BPF_REG_3];

	tc_ninsn = insn ? ctx->offset[insn] - ctx->offset[insn - 1] :
		ctx->offset[0];

	/* max_entries = array->map.max_entries; */
	off = offsetof(struct bpf_array, map.max_entries);
	if (is_12b_check(off, insn))
		return -1;
	emit(rv_lw(RV_REG_T1, off, lo(arr_reg)), ctx);

	/*
	 * if (index >= max_entries)
	 *   goto out;
	 */
	off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
	emit_bcc(BPF_JGE, lo(idx_reg), RV_REG_T1, off, ctx);

	/*
	 * if (--tcc < 0)
	 *   goto out;
	 */
	emit(rv_addi(RV_REG_TCC, RV_REG_TCC, -1), ctx);
	off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
	emit_bcc(BPF_JSLT, RV_REG_TCC, RV_REG_ZERO, off, ctx);

	/*
	 * prog = array->ptrs[index];
	 * if (!prog)
	 *   goto out;
	 */
	emit(rv_slli(RV_REG_T0, lo(idx_reg), 2), ctx);
	emit(rv_add(RV_REG_T0, RV_REG_T0, lo(arr_reg)), ctx);
	off = offsetof(struct bpf_array, ptrs);
	if (is_12b_check(off, insn))
		return -1;
	emit(rv_lw(RV_REG_T0, off, RV_REG_T0), ctx);
	off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
	emit_bcc(BPF_JEQ, RV_REG_T0, RV_REG_ZERO, off, ctx);

	/*
	 * tcc = temp_tcc;
	 * goto *(prog->bpf_func + 4);
	 */
	off = offsetof(struct bpf_prog, bpf_func);
	if (is_12b_check(off, insn))
		return -1;
	emit(rv_lw(RV_REG_T0, off, RV_REG_T0), ctx);
	/* Epilogue jumps to *(t0 + 4). */
	__build_epilogue(true, ctx);
	return 0;
}

static int emit_load_r64(const s8 *dst, const s8 *src, s16 off,
			 struct rv_jit_context *ctx, const u8 size)
{
	const s8 *tmp1 = bpf2rv32[TMP_REG_1];
	const s8 *tmp2 = bpf2rv32[TMP_REG_2];
	const s8 *rd = bpf_get_reg64(dst, tmp1, ctx);
	const s8 *rs = bpf_get_reg64(src, tmp2, ctx);

	emit_imm(RV_REG_T0, off, ctx);
	emit(rv_add(RV_REG_T0, RV_REG_T0, lo(rs)), ctx);

	switch (size) {
	case BPF_B:
		emit(rv_lbu(lo(rd), 0, RV_REG_T0), ctx);
		if (!ctx->prog->aux->verifier_zext)
			emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
		break;
	case BPF_H:
		emit(rv_lhu(lo(rd), 0, RV_REG_T0), ctx);
		if (!ctx->prog->aux->verifier_zext)
			emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
		break;
	case BPF_W:
		emit(rv_lw(lo(rd), 0, RV_REG_T0), ctx);
		if (!ctx->prog->aux->verifier_zext)
			emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
		break;
	case BPF_DW:
		emit(rv_lw(lo(rd), 0, RV_REG_T0), ctx);
		emit(rv_lw(hi(rd), 4, RV_REG_T0), ctx);
		break;
	}

	bpf_put_reg64(dst, rd, ctx);
	return 0;
}

static int emit_store_r64(const s8 *dst, const s8 *src, s16 off,
			  struct rv_jit_context *ctx, const u8 size,
			  const u8 mode)
{
	const s8 *tmp1 = bpf2rv32[TMP_REG_1];
	const s8 *tmp2 = bpf2rv32[TMP_REG_2];
	const s8 *rd = bpf_get_reg64(dst, tmp1, ctx);
	const s8 *rs = bpf_get_reg64(src, tmp2, ctx);

	if (mode == BPF_ATOMIC && size != BPF_W)
		return -1;

	emit_imm(RV_REG_T0, off, ctx);
	emit(rv_add(RV_REG_T0, RV_REG_T0, lo(rd)), ctx);

	switch (size) {
	case BPF_B:
		emit(rv_sb(RV_REG_T0, 0, lo(rs)), ctx);
		break;
	case BPF_H:
		emit(rv_sh(RV_REG_T0, 0, lo(rs)), ctx);
		break;
	case BPF_W:
		switch (mode) {
		case BPF_MEM:
			emit(rv_sw(RV_REG_T0, 0, lo(rs)), ctx);
			break;
		case BPF_ATOMIC: /* Only BPF_ADD supported */
			emit(rv_amoadd_w(RV_REG_ZERO, lo(rs), RV_REG_T0, 0, 0),
			     ctx);
			break;
		}
		break;
	case BPF_DW:
		emit(rv_sw(RV_REG_T0, 0, lo(rs)), ctx);
		emit(rv_sw(RV_REG_T0, 4, hi(rs)), ctx);
		break;
	}

	return 0;
}

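/* Byte-swap the low 16 bits of rd: e.g. 0x0000aabb becomes 0x0000bbaa. */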
static void emit_rev16(const s8 rd, struct rv_jit_context *ctx)
{
	emit(rv_slli(rd, rd, 16), ctx);
	emit(rv_slli(RV_REG_T1, rd, 8), ctx);
	emit(rv_srli(rd, rd, 8), ctx);
	emit(rv_add(RV_REG_T1, rd, RV_REG_T1), ctx);
	emit(rv_srli(rd, RV_REG_T1, 16), ctx);
}

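/*
 * Byte-swap all 32 bits of rd by peeling bytes off the low end and
 * accumulating them into T1, shifting T1 left by 8 between bytes.
 */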
static void emit_rev32(const s8 rd, struct rv_jit_context *ctx)
{
	emit(rv_addi(RV_REG_T1, RV_REG_ZERO, 0), ctx);
	emit(rv_andi(RV_REG_T0, rd, 255), ctx);
	emit(rv_add(RV_REG_T1, RV_REG_T1, RV_REG_T0), ctx);
	emit(rv_slli(RV_REG_T1, RV_REG_T1, 8), ctx);
	emit(rv_srli(rd, rd, 8), ctx);
	emit(rv_andi(RV_REG_T0, rd, 255), ctx);
	emit(rv_add(RV_REG_T1, RV_REG_T1, RV_REG_T0), ctx);
	emit(rv_slli(RV_REG_T1, RV_REG_T1, 8), ctx);
	emit(rv_srli(rd, rd, 8), ctx);
	emit(rv_andi(RV_REG_T0, rd, 255), ctx);
	emit(rv_add(RV_REG_T1, RV_REG_T1, RV_REG_T0), ctx);
	emit(rv_slli(RV_REG_T1, RV_REG_T1, 8), ctx);
	emit(rv_srli(rd, rd, 8), ctx);
	emit(rv_andi(RV_REG_T0, rd, 255), ctx);
	emit(rv_add(RV_REG_T1, RV_REG_T1, RV_REG_T0), ctx);
	emit(rv_addi(rd, RV_REG_T1, 0), ctx);
}

static void emit_zext64(const s8 *dst, struct rv_jit_context *ctx)
{
	const s8 *rd;
	const s8 *tmp1 = bpf2rv32[TMP_REG_1];

	rd = bpf_get_reg64(dst, tmp1, ctx);
	emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
	bpf_put_reg64(dst, rd, ctx);
}

int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
		      bool extra_pass)
{
	bool is64 = BPF_CLASS(insn->code) == BPF_ALU64 ||
		BPF_CLASS(insn->code) == BPF_JMP;
	int s, e, rvoff, i = insn - ctx->prog->insnsi;
	u8 code = insn->code;
	s16 off = insn->off;
	s32 imm = insn->imm;

	const s8 *dst = bpf2rv32[insn->dst_reg];
	const s8 *src = bpf2rv32[insn->src_reg];
	const s8 *tmp1 = bpf2rv32[TMP_REG_1];
	const s8 *tmp2 = bpf2rv32[TMP_REG_2];

	switch (code) {
	case BPF_ALU64 | BPF_MOV | BPF_X:

	case BPF_ALU64 | BPF_ADD | BPF_X:
	case BPF_ALU64 | BPF_ADD | BPF_K:

	case BPF_ALU64 | BPF_SUB | BPF_X:
	case BPF_ALU64 | BPF_SUB | BPF_K:

	case BPF_ALU64 | BPF_AND | BPF_X:
	case BPF_ALU64 | BPF_OR | BPF_X:
	case BPF_ALU64 | BPF_XOR | BPF_X:

	case BPF_ALU64 | BPF_MUL | BPF_X:
	case BPF_ALU64 | BPF_MUL | BPF_K:

	case BPF_ALU64 | BPF_LSH | BPF_X:
	case BPF_ALU64 | BPF_RSH | BPF_X:
	case BPF_ALU64 | BPF_ARSH | BPF_X:
		if (BPF_SRC(code) == BPF_K) {
			emit_imm32(tmp2, imm, ctx);
			src = tmp2;
		}
		emit_alu_r64(dst, src, ctx, BPF_OP(code));
		break;

	case BPF_ALU64 | BPF_NEG:
		emit_alu_r64(dst, tmp2, ctx, BPF_OP(code));
		break;

	case BPF_ALU64 | BPF_DIV | BPF_X:
	case BPF_ALU64 | BPF_DIV | BPF_K:
	case BPF_ALU64 | BPF_MOD | BPF_X:
	case BPF_ALU64 | BPF_MOD | BPF_K:
		goto notsupported;

	case BPF_ALU64 | BPF_MOV | BPF_K:
	case BPF_ALU64 | BPF_AND | BPF_K:
	case BPF_ALU64 | BPF_OR | BPF_K:
	case BPF_ALU64 | BPF_XOR | BPF_K:
	case BPF_ALU64 | BPF_LSH | BPF_K:
	case BPF_ALU64 | BPF_RSH | BPF_K:
	case BPF_ALU64 | BPF_ARSH | BPF_K:
		emit_alu_i64(dst, imm, ctx, BPF_OP(code));
		break;

	case BPF_ALU | BPF_MOV | BPF_X:
		if (imm == 1) {
			/* Special mov32 for zext. */
			emit_zext64(dst, ctx);
			break;
		}
		fallthrough;

	case BPF_ALU | BPF_ADD | BPF_X:
	case BPF_ALU | BPF_SUB | BPF_X:
	case BPF_ALU | BPF_AND | BPF_X:
	case BPF_ALU | BPF_OR | BPF_X:
	case BPF_ALU | BPF_XOR | BPF_X:

	case BPF_ALU | BPF_MUL | BPF_X:
	case BPF_ALU | BPF_MUL | BPF_K:

	case BPF_ALU | BPF_DIV | BPF_X:
	case BPF_ALU | BPF_DIV | BPF_K:

	case BPF_ALU | BPF_MOD | BPF_X:
	case BPF_ALU | BPF_MOD | BPF_K:

	case BPF_ALU | BPF_LSH | BPF_X:
	case BPF_ALU | BPF_RSH | BPF_X:
	case BPF_ALU | BPF_ARSH | BPF_X:
		if (BPF_SRC(code) == BPF_K) {
			emit_imm32(tmp2, imm, ctx);
			src = tmp2;
		}
		emit_alu_r32(dst, src, ctx, BPF_OP(code));
		break;

	case BPF_ALU | BPF_MOV | BPF_K:
	case BPF_ALU | BPF_ADD | BPF_K:
	case BPF_ALU | BPF_SUB | BPF_K:
	case BPF_ALU | BPF_AND | BPF_K:
	case BPF_ALU | BPF_OR | BPF_K:
	case BPF_ALU | BPF_XOR | BPF_K:
	case BPF_ALU | BPF_LSH | BPF_K:
	case BPF_ALU | BPF_RSH | BPF_K:
	case BPF_ALU | BPF_ARSH | BPF_K:
		/*
		 * mul,div,mod are handled in the BPF_X case since there are
		 * no RISC-V I-type equivalents.
		 */
		emit_alu_i32(dst, imm, ctx, BPF_OP(code));
		break;

	case BPF_ALU | BPF_NEG:
		/*
		 * src is ignored---choose tmp2 as a dummy register since it
		 * is not on the stack.
		 */
		emit_alu_r32(dst, tmp2, ctx, BPF_OP(code));
		break;

	case BPF_ALU | BPF_END | BPF_FROM_LE:
	{
		const s8 *rd = bpf_get_reg64(dst, tmp1, ctx);

		switch (imm) {
		case 16:
			emit(rv_slli(lo(rd), lo(rd), 16), ctx);
			emit(rv_srli(lo(rd), lo(rd), 16), ctx);
			fallthrough;
		case 32:
			if (!ctx->prog->aux->verifier_zext)
				emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
			break;
		case 64:
			/* Do nothing. */
			break;
		default:
			pr_err("bpf-jit: BPF_END imm %d invalid\n", imm);
			return -1;
		}

		bpf_put_reg64(dst, rd, ctx);
		break;
	}

	case BPF_ALU | BPF_END | BPF_FROM_BE:
	{
		const s8 *rd = bpf_get_reg64(dst, tmp1, ctx);

		switch (imm) {
		case 16:
			emit_rev16(lo(rd), ctx);
			if (!ctx->prog->aux->verifier_zext)
				emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
			break;
		case 32:
			emit_rev32(lo(rd), ctx);
			if (!ctx->prog->aux->verifier_zext)
				emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
			break;
		case 64:
			/* Swap upper and lower halves. */
			emit(rv_addi(RV_REG_T0, lo(rd), 0), ctx);
			emit(rv_addi(lo(rd), hi(rd), 0), ctx);
			emit(rv_addi(hi(rd), RV_REG_T0, 0), ctx);

			/* Swap each half. */
			emit_rev32(lo(rd), ctx);
			emit_rev32(hi(rd), ctx);
			break;
		default:
			pr_err("bpf-jit: BPF_END imm %d invalid\n", imm);
			return -1;
		}

		bpf_put_reg64(dst, rd, ctx);
		break;
	}

	case BPF_JMP | BPF_JA:
		rvoff = rv_offset(i, off, ctx);
		emit_jump_and_link(RV_REG_ZERO, rvoff, false, ctx);
		break;

	case BPF_JMP | BPF_CALL:
	{
		bool fixed;
		int ret;
		u64 addr;

		ret = bpf_jit_get_func_addr(ctx->prog, insn, extra_pass, &addr,
					    &fixed);
		if (ret < 0)
			return ret;
		emit_call(fixed, addr, ctx);
		break;
	}

	case BPF_JMP | BPF_TAIL_CALL:
		if (emit_bpf_tail_call(i, ctx))
			return -1;
		break;

	case BPF_JMP | BPF_JEQ | BPF_X:
	case BPF_JMP | BPF_JEQ | BPF_K:
	case BPF_JMP32 | BPF_JEQ | BPF_X:
	case BPF_JMP32 | BPF_JEQ | BPF_K:

	case BPF_JMP | BPF_JNE | BPF_X:
	case BPF_JMP | BPF_JNE | BPF_K:
	case BPF_JMP32 | BPF_JNE | BPF_X:
	case BPF_JMP32 | BPF_JNE | BPF_K:

	case BPF_JMP | BPF_JLE | BPF_X:
	case BPF_JMP | BPF_JLE | BPF_K:
	case BPF_JMP32 | BPF_JLE | BPF_X:
	case BPF_JMP32 | BPF_JLE | BPF_K:

	case BPF_JMP | BPF_JLT | BPF_X:
	case BPF_JMP | BPF_JLT | BPF_K:
	case BPF_JMP32 | BPF_JLT | BPF_X:
	case BPF_JMP32 | BPF_JLT | BPF_K:

	case BPF_JMP | BPF_JGE | BPF_X:
	case BPF_JMP | BPF_JGE | BPF_K:
	case BPF_JMP32 | BPF_JGE | BPF_X:
	case BPF_JMP32 | BPF_JGE | BPF_K:

	case BPF_JMP | BPF_JGT | BPF_X:
	case BPF_JMP | BPF_JGT | BPF_K:
	case BPF_JMP32 | BPF_JGT | BPF_X:
	case BPF_JMP32 | BPF_JGT | BPF_K:

	case BPF_JMP | BPF_JSLE | BPF_X:
	case BPF_JMP | BPF_JSLE | BPF_K:
	case BPF_JMP32 | BPF_JSLE | BPF_X:
	case BPF_JMP32 | BPF_JSLE | BPF_K:

	case BPF_JMP | BPF_JSLT | BPF_X:
	case BPF_JMP | BPF_JSLT | BPF_K:
	case BPF_JMP32 | BPF_JSLT | BPF_X:
	case BPF_JMP32 | BPF_JSLT | BPF_K:

	case BPF_JMP | BPF_JSGE | BPF_X:
	case BPF_JMP | BPF_JSGE | BPF_K:
	case BPF_JMP32 | BPF_JSGE | BPF_X:
	case BPF_JMP32 | BPF_JSGE | BPF_K:

	case BPF_JMP | BPF_JSGT | BPF_X:
	case BPF_JMP | BPF_JSGT | BPF_K:
	case BPF_JMP32 | BPF_JSGT | BPF_X:
	case BPF_JMP32 | BPF_JSGT | BPF_K:

	case BPF_JMP | BPF_JSET | BPF_X:
	case BPF_JMP | BPF_JSET | BPF_K:
	case BPF_JMP32 | BPF_JSET | BPF_X:
	case BPF_JMP32 | BPF_JSET | BPF_K:
		rvoff = rv_offset(i, off, ctx);
		if (BPF_SRC(code) == BPF_K) {
			s = ctx->ninsns;
			emit_imm32(tmp2, imm, ctx);
			src = tmp2;
			e = ctx->ninsns;
			rvoff -= ninsns_rvoff(e - s);
		}

		if (is64)
			emit_branch_r64(dst, src, rvoff, ctx, BPF_OP(code));
		else
			emit_branch_r32(dst, src, rvoff, ctx, BPF_OP(code));
		break;

	case BPF_JMP | BPF_EXIT:
		if (i == ctx->prog->len - 1)
			break;

		rvoff = epilogue_offset(ctx);
		emit_jump_and_link(RV_REG_ZERO, rvoff, false, ctx);
		break;

	case BPF_LD | BPF_IMM | BPF_DW:
	{
		struct bpf_insn insn1 = insn[1];
		s32 imm_lo = imm;
		s32 imm_hi = insn1.imm;
		const s8 *rd = bpf_get_reg64(dst, tmp1, ctx);

		emit_imm64(rd, imm_hi, imm_lo, ctx);
		bpf_put_reg64(dst, rd, ctx);
		return 1;
	}

	case BPF_LDX | BPF_MEM | BPF_B:
	case BPF_LDX | BPF_MEM | BPF_H:
	case BPF_LDX | BPF_MEM | BPF_W:
	case BPF_LDX | BPF_MEM | BPF_DW:
		if (emit_load_r64(dst, src, off, ctx, BPF_SIZE(code)))
			return -1;
		break;

	/* speculation barrier */
	case BPF_ST | BPF_NOSPEC:
		break;

	case BPF_ST | BPF_MEM | BPF_B:
	case BPF_ST | BPF_MEM | BPF_H:
	case BPF_ST | BPF_MEM | BPF_W:
	case BPF_ST | BPF_MEM | BPF_DW:

	case BPF_STX | BPF_MEM | BPF_B:
	case BPF_STX | BPF_MEM | BPF_H:
	case BPF_STX | BPF_MEM | BPF_W:
	case BPF_STX | BPF_MEM | BPF_DW:
		if (BPF_CLASS(code) == BPF_ST) {
			emit_imm32(tmp2, imm, ctx);
			src = tmp2;
		}

		if (emit_store_r64(dst, src, off, ctx, BPF_SIZE(code),
				   BPF_MODE(code)))
			return -1;
		break;

	case BPF_STX | BPF_ATOMIC | BPF_W:
		if (insn->imm != BPF_ADD) {
			pr_info_once(
				"bpf-jit: not supported: atomic operation %02x ***\n",
				insn->imm);
			return -EFAULT;
		}

		if (emit_store_r64(dst, src, off, ctx, BPF_SIZE(code),
				   BPF_MODE(code)))
			return -1;
		break;

	/* No hardware support for 8-byte atomics in RV32. */
	case BPF_STX | BPF_ATOMIC | BPF_DW:
		/* Fallthrough. */

notsupported:
		pr_info_once("bpf-jit: not supported: opcode %02x ***\n", code);
		return -EFAULT;

	default:
		pr_err("bpf-jit: unknown opcode %02x\n", code);
		return -EINVAL;
	}

	return 0;
}

void bpf_jit_build_prologue(struct rv_jit_context *ctx, bool is_subprog)
{
	const s8 *fp = bpf2rv32[BPF_REG_FP];
	const s8 *r1 = bpf2rv32[BPF_REG_1];
	int stack_adjust = 0;
	int bpf_stack_adjust =
		round_up(ctx->prog->aux->stack_depth, STACK_ALIGN);

	/* Make space for callee-saved registers. */
	stack_adjust += NR_SAVED_REGISTERS * sizeof(u32);
	/* Make space for BPF registers on stack. */
	stack_adjust += BPF_JIT_SCRATCH_REGS * sizeof(u32);
	/* Make space for BPF stack. */
	stack_adjust += bpf_stack_adjust;
	/* Round up for stack alignment. */
	stack_adjust = round_up(stack_adjust, STACK_ALIGN);
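
	/*
	 * E.g. for a 64-byte BPF stack: 36 bytes of saved registers plus
	 * 40 bytes of scratch space plus 64 gives 140, rounded up to the
	 * next STACK_ALIGN boundary (144 with 16-byte alignment).
	 */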

	/*
	 * The first instruction sets the tail-call-counter (TCC) register.
	 * This instruction is skipped by tail calls.
	 */
	emit(rv_addi(RV_REG_TCC, RV_REG_ZERO, MAX_TAIL_CALL_CNT), ctx);

	emit(rv_addi(RV_REG_SP, RV_REG_SP, -stack_adjust), ctx);

	/* Save callee-save registers. */
	emit(rv_sw(RV_REG_SP, stack_adjust - 4, RV_REG_RA), ctx);
	emit(rv_sw(RV_REG_SP, stack_adjust - 8, RV_REG_FP), ctx);
	emit(rv_sw(RV_REG_SP, stack_adjust - 12, RV_REG_S1), ctx);
	emit(rv_sw(RV_REG_SP, stack_adjust - 16, RV_REG_S2), ctx);
	emit(rv_sw(RV_REG_SP, stack_adjust - 20, RV_REG_S3), ctx);
	emit(rv_sw(RV_REG_SP, stack_adjust - 24, RV_REG_S4), ctx);
	emit(rv_sw(RV_REG_SP, stack_adjust - 28, RV_REG_S5), ctx);
	emit(rv_sw(RV_REG_SP, stack_adjust - 32, RV_REG_S6), ctx);
	emit(rv_sw(RV_REG_SP, stack_adjust - 36, RV_REG_S7), ctx);

	/* Set fp: used as the base address for stacked BPF registers. */
	emit(rv_addi(RV_REG_FP, RV_REG_SP, stack_adjust), ctx);

	/* Set up BPF frame pointer. */
	emit(rv_addi(lo(fp), RV_REG_SP, bpf_stack_adjust), ctx);
	emit(rv_addi(hi(fp), RV_REG_ZERO, 0), ctx);

	/* Set up BPF context pointer. */
	emit(rv_addi(lo(r1), RV_REG_A0, 0), ctx);
	emit(rv_addi(hi(r1), RV_REG_ZERO, 0), ctx);

	ctx->stack_size = stack_adjust;
}

void bpf_jit_build_epilogue(struct rv_jit_context *ctx)
{
	__build_epilogue(false, ctx);
}