cachepc-linux

Fork of AMDESE/linux with modifications for the CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux

bpf_jit_comp.c (40123B)


      1// SPDX-License-Identifier: GPL-2.0-only
      2/*
      3 * BPF JIT compiler for ARM64
      4 *
      5 * Copyright (C) 2014-2016 Zi Shen Lim <zlim.lnx@gmail.com>
      6 */
      7
      8#define pr_fmt(fmt) "bpf_jit: " fmt
      9
     10#include <linux/bitfield.h>
     11#include <linux/bpf.h>
     12#include <linux/filter.h>
     13#include <linux/printk.h>
     14#include <linux/slab.h>
     15
     16#include <asm/asm-extable.h>
     17#include <asm/byteorder.h>
     18#include <asm/cacheflush.h>
     19#include <asm/debug-monitors.h>
     20#include <asm/insn.h>
     21#include <asm/set_memory.h>
     22
     23#include "bpf_jit.h"
     24
     25#define TMP_REG_1 (MAX_BPF_JIT_REG + 0)
     26#define TMP_REG_2 (MAX_BPF_JIT_REG + 1)
     27#define TCALL_CNT (MAX_BPF_JIT_REG + 2)
     28#define TMP_REG_3 (MAX_BPF_JIT_REG + 3)
     29#define FP_BOTTOM (MAX_BPF_JIT_REG + 4)
     30
     31#define check_imm(bits, imm) do {				\
     32	if ((((imm) > 0) && ((imm) >> (bits))) ||		\
     33	    (((imm) < 0) && (~(imm) >> (bits)))) {		\
     34		pr_info("[%2d] imm=%d(0x%x) out of range\n",	\
     35			i, imm, imm);				\
     36		return -EINVAL;					\
     37	}							\
     38} while (0)
     39#define check_imm19(imm) check_imm(19, imm)
     40#define check_imm26(imm) check_imm(26, imm)
     41
     42/* Map BPF registers to A64 registers */
     43static const int bpf2a64[] = {
     44	/* return value from in-kernel function, and exit value from eBPF */
     45	[BPF_REG_0] = A64_R(7),
     46	/* arguments from eBPF program to in-kernel function */
     47	[BPF_REG_1] = A64_R(0),
     48	[BPF_REG_2] = A64_R(1),
     49	[BPF_REG_3] = A64_R(2),
     50	[BPF_REG_4] = A64_R(3),
     51	[BPF_REG_5] = A64_R(4),
     52	/* callee saved registers that in-kernel function will preserve */
     53	[BPF_REG_6] = A64_R(19),
     54	[BPF_REG_7] = A64_R(20),
     55	[BPF_REG_8] = A64_R(21),
     56	[BPF_REG_9] = A64_R(22),
     57	/* read-only frame pointer to access stack */
     58	[BPF_REG_FP] = A64_R(25),
     59	/* temporary registers for BPF JIT */
     60	[TMP_REG_1] = A64_R(10),
     61	[TMP_REG_2] = A64_R(11),
     62	[TMP_REG_3] = A64_R(12),
     63	/* tail_call_cnt */
     64	[TCALL_CNT] = A64_R(26),
     65	/* temporary register for blinding constants */
     66	[BPF_REG_AX] = A64_R(9),
     67	[FP_BOTTOM] = A64_R(27),
     68};
     69
     70struct jit_ctx {
     71	const struct bpf_prog *prog;
     72	int idx;
     73	int epilogue_offset;
     74	int *offset;
     75	int exentry_idx;
     76	__le32 *image;
     77	u32 stack_size;
     78	int fpb_offset;
     79};
     80
     81static inline void emit(const u32 insn, struct jit_ctx *ctx)
     82{
     83	if (ctx->image != NULL)
     84		ctx->image[ctx->idx] = cpu_to_le32(insn);
     85
     86	ctx->idx++;
     87}
     88
     89static inline void emit_a64_mov_i(const int is64, const int reg,
     90				  const s32 val, struct jit_ctx *ctx)
     91{
     92	u16 hi = val >> 16;
     93	u16 lo = val & 0xffff;
     94
     95	if (hi & 0x8000) {
     96		if (hi == 0xffff) {
     97			emit(A64_MOVN(is64, reg, (u16)~lo, 0), ctx);
     98		} else {
     99			emit(A64_MOVN(is64, reg, (u16)~hi, 16), ctx);
    100			if (lo != 0xffff)
    101				emit(A64_MOVK(is64, reg, lo, 0), ctx);
    102		}
    103	} else {
    104		emit(A64_MOVZ(is64, reg, lo, 0), ctx);
    105		if (hi)
    106			emit(A64_MOVK(is64, reg, hi, 16), ctx);
    107	}
    108}
    109
    110static int i64_i16_blocks(const u64 val, bool inverse)
    111{
    112	return (((val >>  0) & 0xffff) != (inverse ? 0xffff : 0x0000)) +
    113	       (((val >> 16) & 0xffff) != (inverse ? 0xffff : 0x0000)) +
    114	       (((val >> 32) & 0xffff) != (inverse ? 0xffff : 0x0000)) +
    115	       (((val >> 48) & 0xffff) != (inverse ? 0xffff : 0x0000));
    116}
    117
    118static inline void emit_a64_mov_i64(const int reg, const u64 val,
    119				    struct jit_ctx *ctx)
    120{
    121	u64 nrm_tmp = val, rev_tmp = ~val;
    122	bool inverse;
    123	int shift;
    124
    125	if (!(nrm_tmp >> 32))
    126		return emit_a64_mov_i(0, reg, (u32)val, ctx);
    127
    128	inverse = i64_i16_blocks(nrm_tmp, true) < i64_i16_blocks(nrm_tmp, false);
    129	shift = max(round_down((inverse ? (fls64(rev_tmp) - 1) :
    130					  (fls64(nrm_tmp) - 1)), 16), 0);
    131	if (inverse)
    132		emit(A64_MOVN(1, reg, (rev_tmp >> shift) & 0xffff, shift), ctx);
    133	else
    134		emit(A64_MOVZ(1, reg, (nrm_tmp >> shift) & 0xffff, shift), ctx);
    135	shift -= 16;
    136	while (shift >= 0) {
    137		if (((nrm_tmp >> shift) & 0xffff) != (inverse ? 0xffff : 0x0000))
    138			emit(A64_MOVK(1, reg, (nrm_tmp >> shift) & 0xffff, shift), ctx);
    139		shift -= 16;
    140	}
    141}
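        /*
         * Worked example (illustrative value): for val = 0xffffffffffff1234
         * only one 16-bit block differs from 0xffff, so the inverse form wins
         * (rev_tmp = 0xedcb, shift = 0) and a single
         *
         *	movn	reg, #0xedcb		// reg = 0xffffffffffff1234
         *
         * is emitted instead of a movz plus three movk instructions.
         */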
    142
    143/*
    144 * Kernel addresses in the vmalloc space use at most 48 bits, and the
     145 * remaining bits are guaranteed to be all ones. So we can compose the address
    146 * with a fixed length movn/movk/movk sequence.
    147 */
    148static inline void emit_addr_mov_i64(const int reg, const u64 val,
    149				     struct jit_ctx *ctx)
    150{
    151	u64 tmp = val;
    152	int shift = 0;
    153
    154	emit(A64_MOVN(1, reg, ~tmp & 0xffff, shift), ctx);
    155	while (shift < 32) {
    156		tmp >>= 16;
    157		shift += 16;
    158		emit(A64_MOVK(1, reg, tmp & 0xffff, shift), ctx);
    159	}
    160}
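        /*
         * Worked example (arbitrary vmalloc-range address): for
         * val = 0xffff800012345678 the sequence above is always
         *
         *	movn	reg, #0xa987		// reg = 0xffffffffffff5678
         *	movk	reg, #0x1234, lsl #16	// reg = 0xffffffff12345678
         *	movk	reg, #0x8000, lsl #32	// reg = 0xffff800012345678
         *
         * Bits 63:48 are already all ones from the movn, so three instructions
         * always suffice and the emitted size stays the same across JIT passes,
         * even when the final address is only known on the extra pass.
         */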
    161
    162static inline int bpf2a64_offset(int bpf_insn, int off,
    163				 const struct jit_ctx *ctx)
    164{
    165	/* BPF JMP offset is relative to the next instruction */
    166	bpf_insn++;
    167	/*
     168	 * However, arm64 branch instructions encode the offset
    169	 * from the branch itself, so we must subtract 1 from the
    170	 * instruction offset.
    171	 */
    172	return ctx->offset[bpf_insn + off] - (ctx->offset[bpf_insn] - 1);
    173}
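        /*
         * Example: ctx->offset[n] is the A64 index at which BPF instruction n
         * starts, so the branch emitted for BPF instruction i is the A64
         * instruction at ctx->offset[i + 1] - 1. With ctx->offset = {0, 2, 3, 5},
         * a JA with off = 1 at BPF instruction 1 branches from A64 index 2 to
         * A64 index 5 and this helper returns 5 - 2 = 3.
         */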
    174
    175static void jit_fill_hole(void *area, unsigned int size)
    176{
    177	__le32 *ptr;
    178	/* We are guaranteed to have aligned memory. */
    179	for (ptr = area; size >= sizeof(u32); size -= sizeof(u32))
    180		*ptr++ = cpu_to_le32(AARCH64_BREAK_FAULT);
    181}
    182
    183static inline int epilogue_offset(const struct jit_ctx *ctx)
    184{
    185	int to = ctx->epilogue_offset;
    186	int from = ctx->idx;
    187
    188	return to - from;
    189}
    190
    191static bool is_addsub_imm(u32 imm)
    192{
    193	/* Either imm12 or shifted imm12. */
    194	return !(imm & ~0xfff) || !(imm & ~0xfff000);
    195}
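        /*
         * A64 ADD/SUB (immediate) take a 12-bit unsigned immediate, optionally
         * shifted left by 12: e.g. 0xfff and 0xabc000 are encodable directly,
         * while 0x1001 is not and must be moved into a tmp register first.
         */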
    196
    197/*
    198 * There are 3 types of AArch64 LDR/STR (immediate) instruction:
    199 * Post-index, Pre-index, Unsigned offset.
    200 *
    201 * For BPF ldr/str, the "unsigned offset" type is sufficient.
    202 *
    203 * "Unsigned offset" type LDR(immediate) format:
    204 *
    205 *    3                   2                   1                   0
    206 *  1 0 9 8 7 6 5 4 3 2 1 0 9 8 7 6 5 4 3 2 1 0 9 8 7 6 5 4 3 2 1 0
    207 * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    208 * |x x|1 1 1 0 0 1 0 1|         imm12         |    Rn   |    Rt   |
    209 * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    210 * scale
    211 *
    212 * "Unsigned offset" type STR(immediate) format:
    213 *    3                   2                   1                   0
    214 *  1 0 9 8 7 6 5 4 3 2 1 0 9 8 7 6 5 4 3 2 1 0 9 8 7 6 5 4 3 2 1 0
    215 * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    216 * |x x|1 1 1 0 0 1 0 0|         imm12         |    Rn   |    Rt   |
    217 * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    218 * scale
    219 *
    220 * The offset is calculated from imm12 and scale in the following way:
    221 *
    222 * offset = (u64)imm12 << scale
    223 */
    224static bool is_lsi_offset(int offset, int scale)
    225{
    226	if (offset < 0)
    227		return false;
    228
    229	if (offset > (0xFFF << scale))
    230		return false;
    231
    232	if (offset & ((1 << scale) - 1))
    233		return false;
    234
    235	return true;
    236}
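        /*
         * Resulting ranges for the unsigned-offset form: scale 0 (BPF_B) covers
         * offsets 0..4095, scale 1 (BPF_H) 0..8190 in steps of 2, scale 2 (BPF_W)
         * 0..16380 in steps of 4 and scale 3 (BPF_DW) 0..32760 in steps of 8.
         * Negative, oversized or misaligned offsets fall back to loading the
         * offset into a tmp register and using the register-offset form.
         */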
    237
    238/* Tail call offset to jump into */
    239#if IS_ENABLED(CONFIG_ARM64_BTI_KERNEL) || \
    240	IS_ENABLED(CONFIG_ARM64_PTR_AUTH_KERNEL)
    241#define PROLOGUE_OFFSET 9
    242#else
    243#define PROLOGUE_OFFSET 8
    244#endif
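        /*
         * PROLOGUE_OFFSET counts the instructions build_prologue() emits before
         * the tail call landing pad: an optional paciasp/bti c, stp x29/x30,
         * mov x29, four stp pairs for the callee-saved registers, a mov to set
         * up the BPF frame pointer and a movz to zero tail_call_cnt, i.e. 8
         * instructions, or 9 with PAC or BTI enabled. emit_bpf_tail_call()
         * branches to bpf_func + 4 * PROLOGUE_OFFSET, so tail_call_cnt is not
         * reset by the target program.
         */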
    245
    246static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf)
    247{
    248	const struct bpf_prog *prog = ctx->prog;
    249	const u8 r6 = bpf2a64[BPF_REG_6];
    250	const u8 r7 = bpf2a64[BPF_REG_7];
    251	const u8 r8 = bpf2a64[BPF_REG_8];
    252	const u8 r9 = bpf2a64[BPF_REG_9];
    253	const u8 fp = bpf2a64[BPF_REG_FP];
    254	const u8 tcc = bpf2a64[TCALL_CNT];
    255	const u8 fpb = bpf2a64[FP_BOTTOM];
    256	const int idx0 = ctx->idx;
    257	int cur_offset;
    258
    259	/*
    260	 * BPF prog stack layout
    261	 *
    262	 *                         high
    263	 * original A64_SP =>   0:+-----+ BPF prologue
    264	 *                        |FP/LR|
    265	 * current A64_FP =>  -16:+-----+
    266	 *                        | ... | callee saved registers
    267	 * BPF fp register => -64:+-----+ <= (BPF_FP)
    268	 *                        |     |
    269	 *                        | ... | BPF prog stack
    270	 *                        |     |
    271	 *                        +-----+ <= (BPF_FP - prog->aux->stack_depth)
    272	 *                        |RSVD | padding
    273	 * current A64_SP =>      +-----+ <= (BPF_FP - ctx->stack_size)
    274	 *                        |     |
    275	 *                        | ... | Function call stack
    276	 *                        |     |
    277	 *                        +-----+
    278	 *                          low
    279	 *
    280	 */
    281
    282	/* Sign lr */
    283	if (IS_ENABLED(CONFIG_ARM64_PTR_AUTH_KERNEL))
    284		emit(A64_PACIASP, ctx);
    285	/* BTI landing pad */
    286	else if (IS_ENABLED(CONFIG_ARM64_BTI_KERNEL))
    287		emit(A64_BTI_C, ctx);
    288
      289	/* Save FP and LR registers to stay aligned with the ARM64 AAPCS */
    290	emit(A64_PUSH(A64_FP, A64_LR, A64_SP), ctx);
    291	emit(A64_MOV(1, A64_FP, A64_SP), ctx);
    292
    293	/* Save callee-saved registers */
    294	emit(A64_PUSH(r6, r7, A64_SP), ctx);
    295	emit(A64_PUSH(r8, r9, A64_SP), ctx);
    296	emit(A64_PUSH(fp, tcc, A64_SP), ctx);
    297	emit(A64_PUSH(fpb, A64_R(28), A64_SP), ctx);
    298
    299	/* Set up BPF prog stack base register */
    300	emit(A64_MOV(1, fp, A64_SP), ctx);
    301
    302	if (!ebpf_from_cbpf) {
    303		/* Initialize tail_call_cnt */
    304		emit(A64_MOVZ(1, tcc, 0, 0), ctx);
    305
    306		cur_offset = ctx->idx - idx0;
    307		if (cur_offset != PROLOGUE_OFFSET) {
    308			pr_err_once("PROLOGUE_OFFSET = %d, expected %d!\n",
    309				    cur_offset, PROLOGUE_OFFSET);
    310			return -1;
    311		}
    312
    313		/* BTI landing pad for the tail call, done with a BR */
    314		if (IS_ENABLED(CONFIG_ARM64_BTI_KERNEL))
    315			emit(A64_BTI_J, ctx);
    316	}
    317
    318	emit(A64_SUB_I(1, fpb, fp, ctx->fpb_offset), ctx);
    319
      320	/* Stack size must be a multiple of 16B */
    321	ctx->stack_size = round_up(prog->aux->stack_depth, 16);
    322
    323	/* Set up function call stack */
    324	emit(A64_SUB_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);
    325	return 0;
    326}
    327
    328static int out_offset = -1; /* initialized on the first pass of build_body() */
    329static int emit_bpf_tail_call(struct jit_ctx *ctx)
    330{
    331	/* bpf_tail_call(void *prog_ctx, struct bpf_array *array, u64 index) */
    332	const u8 r2 = bpf2a64[BPF_REG_2];
    333	const u8 r3 = bpf2a64[BPF_REG_3];
    334
    335	const u8 tmp = bpf2a64[TMP_REG_1];
    336	const u8 prg = bpf2a64[TMP_REG_2];
    337	const u8 tcc = bpf2a64[TCALL_CNT];
    338	const int idx0 = ctx->idx;
    339#define cur_offset (ctx->idx - idx0)
    340#define jmp_offset (out_offset - (cur_offset))
    341	size_t off;
    342
    343	/* if (index >= array->map.max_entries)
    344	 *     goto out;
    345	 */
    346	off = offsetof(struct bpf_array, map.max_entries);
    347	emit_a64_mov_i64(tmp, off, ctx);
    348	emit(A64_LDR32(tmp, r2, tmp), ctx);
    349	emit(A64_MOV(0, r3, r3), ctx);
    350	emit(A64_CMP(0, r3, tmp), ctx);
    351	emit(A64_B_(A64_COND_CS, jmp_offset), ctx);
    352
    353	/*
    354	 * if (tail_call_cnt >= MAX_TAIL_CALL_CNT)
    355	 *     goto out;
    356	 * tail_call_cnt++;
    357	 */
    358	emit_a64_mov_i64(tmp, MAX_TAIL_CALL_CNT, ctx);
    359	emit(A64_CMP(1, tcc, tmp), ctx);
    360	emit(A64_B_(A64_COND_CS, jmp_offset), ctx);
    361	emit(A64_ADD_I(1, tcc, tcc, 1), ctx);
    362
    363	/* prog = array->ptrs[index];
    364	 * if (prog == NULL)
    365	 *     goto out;
    366	 */
    367	off = offsetof(struct bpf_array, ptrs);
    368	emit_a64_mov_i64(tmp, off, ctx);
    369	emit(A64_ADD(1, tmp, r2, tmp), ctx);
    370	emit(A64_LSL(1, prg, r3, 3), ctx);
    371	emit(A64_LDR64(prg, tmp, prg), ctx);
    372	emit(A64_CBZ(1, prg, jmp_offset), ctx);
    373
    374	/* goto *(prog->bpf_func + prologue_offset); */
    375	off = offsetof(struct bpf_prog, bpf_func);
    376	emit_a64_mov_i64(tmp, off, ctx);
    377	emit(A64_LDR64(tmp, prg, tmp), ctx);
    378	emit(A64_ADD_I(1, tmp, tmp, sizeof(u32) * PROLOGUE_OFFSET), ctx);
    379	emit(A64_ADD_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);
    380	emit(A64_BR(tmp), ctx);
    381
    382	/* out: */
    383	if (out_offset == -1)
    384		out_offset = cur_offset;
    385	if (cur_offset != out_offset) {
    386		pr_err_once("tail_call out_offset = %d, expected %d!\n",
    387			    cur_offset, out_offset);
    388		return -1;
    389	}
    390	return 0;
    391#undef cur_offset
    392#undef jmp_offset
    393}
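        /*
         * out_offset is only learned at the end of the first build_body() pass,
         * where ctx->image is NULL and emit() merely advances ctx->idx, so the
         * placeholder jmp_offset values derived from the initial -1 are never
         * written. Later passes emit the real branches, and the final
         * cur_offset check verifies that the block length has not changed
         * between passes.
         */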
    394
    395#ifdef CONFIG_ARM64_LSE_ATOMICS
    396static int emit_lse_atomic(const struct bpf_insn *insn, struct jit_ctx *ctx)
    397{
    398	const u8 code = insn->code;
    399	const u8 dst = bpf2a64[insn->dst_reg];
    400	const u8 src = bpf2a64[insn->src_reg];
    401	const u8 tmp = bpf2a64[TMP_REG_1];
    402	const u8 tmp2 = bpf2a64[TMP_REG_2];
    403	const bool isdw = BPF_SIZE(code) == BPF_DW;
    404	const s16 off = insn->off;
    405	u8 reg;
    406
    407	if (!off) {
    408		reg = dst;
    409	} else {
    410		emit_a64_mov_i(1, tmp, off, ctx);
    411		emit(A64_ADD(1, tmp, tmp, dst), ctx);
    412		reg = tmp;
    413	}
    414
    415	switch (insn->imm) {
    416	/* lock *(u32/u64 *)(dst_reg + off) <op>= src_reg */
    417	case BPF_ADD:
    418		emit(A64_STADD(isdw, reg, src), ctx);
    419		break;
    420	case BPF_AND:
    421		emit(A64_MVN(isdw, tmp2, src), ctx);
    422		emit(A64_STCLR(isdw, reg, tmp2), ctx);
    423		break;
    424	case BPF_OR:
    425		emit(A64_STSET(isdw, reg, src), ctx);
    426		break;
    427	case BPF_XOR:
    428		emit(A64_STEOR(isdw, reg, src), ctx);
    429		break;
    430	/* src_reg = atomic_fetch_<op>(dst_reg + off, src_reg) */
    431	case BPF_ADD | BPF_FETCH:
    432		emit(A64_LDADDAL(isdw, src, reg, src), ctx);
    433		break;
    434	case BPF_AND | BPF_FETCH:
    435		emit(A64_MVN(isdw, tmp2, src), ctx);
    436		emit(A64_LDCLRAL(isdw, src, reg, tmp2), ctx);
    437		break;
    438	case BPF_OR | BPF_FETCH:
    439		emit(A64_LDSETAL(isdw, src, reg, src), ctx);
    440		break;
    441	case BPF_XOR | BPF_FETCH:
    442		emit(A64_LDEORAL(isdw, src, reg, src), ctx);
    443		break;
    444	/* src_reg = atomic_xchg(dst_reg + off, src_reg); */
    445	case BPF_XCHG:
    446		emit(A64_SWPAL(isdw, src, reg, src), ctx);
    447		break;
    448	/* r0 = atomic_cmpxchg(dst_reg + off, r0, src_reg); */
    449	case BPF_CMPXCHG:
    450		emit(A64_CASAL(isdw, src, reg, bpf2a64[BPF_REG_0]), ctx);
    451		break;
    452	default:
    453		pr_err_once("unknown atomic op code %02x\n", insn->imm);
    454		return -EINVAL;
    455	}
    456
    457	return 0;
    458}
    459#else
    460static inline int emit_lse_atomic(const struct bpf_insn *insn, struct jit_ctx *ctx)
    461{
    462	return -EINVAL;
    463}
    464#endif
    465
    466static int emit_ll_sc_atomic(const struct bpf_insn *insn, struct jit_ctx *ctx)
    467{
    468	const u8 code = insn->code;
    469	const u8 dst = bpf2a64[insn->dst_reg];
    470	const u8 src = bpf2a64[insn->src_reg];
    471	const u8 tmp = bpf2a64[TMP_REG_1];
    472	const u8 tmp2 = bpf2a64[TMP_REG_2];
    473	const u8 tmp3 = bpf2a64[TMP_REG_3];
    474	const int i = insn - ctx->prog->insnsi;
    475	const s32 imm = insn->imm;
    476	const s16 off = insn->off;
    477	const bool isdw = BPF_SIZE(code) == BPF_DW;
    478	u8 reg;
    479	s32 jmp_offset;
    480
    481	if (!off) {
    482		reg = dst;
    483	} else {
    484		emit_a64_mov_i(1, tmp, off, ctx);
    485		emit(A64_ADD(1, tmp, tmp, dst), ctx);
    486		reg = tmp;
    487	}
    488
    489	if (imm == BPF_ADD || imm == BPF_AND ||
    490	    imm == BPF_OR || imm == BPF_XOR) {
    491		/* lock *(u32/u64 *)(dst_reg + off) <op>= src_reg */
    492		emit(A64_LDXR(isdw, tmp2, reg), ctx);
    493		if (imm == BPF_ADD)
    494			emit(A64_ADD(isdw, tmp2, tmp2, src), ctx);
    495		else if (imm == BPF_AND)
    496			emit(A64_AND(isdw, tmp2, tmp2, src), ctx);
    497		else if (imm == BPF_OR)
    498			emit(A64_ORR(isdw, tmp2, tmp2, src), ctx);
    499		else
    500			emit(A64_EOR(isdw, tmp2, tmp2, src), ctx);
    501		emit(A64_STXR(isdw, tmp2, reg, tmp3), ctx);
    502		jmp_offset = -3;
    503		check_imm19(jmp_offset);
    504		emit(A64_CBNZ(0, tmp3, jmp_offset), ctx);
    505	} else if (imm == (BPF_ADD | BPF_FETCH) ||
    506		   imm == (BPF_AND | BPF_FETCH) ||
    507		   imm == (BPF_OR | BPF_FETCH) ||
    508		   imm == (BPF_XOR | BPF_FETCH)) {
    509		/* src_reg = atomic_fetch_<op>(dst_reg + off, src_reg) */
    510		const u8 ax = bpf2a64[BPF_REG_AX];
    511
    512		emit(A64_MOV(isdw, ax, src), ctx);
    513		emit(A64_LDXR(isdw, src, reg), ctx);
    514		if (imm == (BPF_ADD | BPF_FETCH))
    515			emit(A64_ADD(isdw, tmp2, src, ax), ctx);
    516		else if (imm == (BPF_AND | BPF_FETCH))
    517			emit(A64_AND(isdw, tmp2, src, ax), ctx);
    518		else if (imm == (BPF_OR | BPF_FETCH))
    519			emit(A64_ORR(isdw, tmp2, src, ax), ctx);
    520		else
    521			emit(A64_EOR(isdw, tmp2, src, ax), ctx);
    522		emit(A64_STLXR(isdw, tmp2, reg, tmp3), ctx);
    523		jmp_offset = -3;
    524		check_imm19(jmp_offset);
    525		emit(A64_CBNZ(0, tmp3, jmp_offset), ctx);
    526		emit(A64_DMB_ISH, ctx);
    527	} else if (imm == BPF_XCHG) {
    528		/* src_reg = atomic_xchg(dst_reg + off, src_reg); */
    529		emit(A64_MOV(isdw, tmp2, src), ctx);
    530		emit(A64_LDXR(isdw, src, reg), ctx);
    531		emit(A64_STLXR(isdw, tmp2, reg, tmp3), ctx);
    532		jmp_offset = -2;
    533		check_imm19(jmp_offset);
    534		emit(A64_CBNZ(0, tmp3, jmp_offset), ctx);
    535		emit(A64_DMB_ISH, ctx);
    536	} else if (imm == BPF_CMPXCHG) {
    537		/* r0 = atomic_cmpxchg(dst_reg + off, r0, src_reg); */
    538		const u8 r0 = bpf2a64[BPF_REG_0];
    539
    540		emit(A64_MOV(isdw, tmp2, r0), ctx);
    541		emit(A64_LDXR(isdw, r0, reg), ctx);
    542		emit(A64_EOR(isdw, tmp3, r0, tmp2), ctx);
    543		jmp_offset = 4;
    544		check_imm19(jmp_offset);
    545		emit(A64_CBNZ(isdw, tmp3, jmp_offset), ctx);
    546		emit(A64_STLXR(isdw, src, reg, tmp3), ctx);
    547		jmp_offset = -4;
    548		check_imm19(jmp_offset);
    549		emit(A64_CBNZ(0, tmp3, jmp_offset), ctx);
    550		emit(A64_DMB_ISH, ctx);
    551	} else {
    552		pr_err_once("unknown atomic op code %02x\n", imm);
    553		return -EINVAL;
    554	}
    555
    556	return 0;
    557}
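        /*
         * The small negative CBNZ offsets above (-2, -3, -4) branch back to the
         * LDXR when the exclusive store reports failure, forming the usual
         * LL/SC retry loop; the trailing DMB ISH supplies the full-barrier
         * semantics required by the value-returning forms.
         */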
    558
    559static void build_epilogue(struct jit_ctx *ctx)
    560{
    561	const u8 r0 = bpf2a64[BPF_REG_0];
    562	const u8 r6 = bpf2a64[BPF_REG_6];
    563	const u8 r7 = bpf2a64[BPF_REG_7];
    564	const u8 r8 = bpf2a64[BPF_REG_8];
    565	const u8 r9 = bpf2a64[BPF_REG_9];
    566	const u8 fp = bpf2a64[BPF_REG_FP];
    567	const u8 fpb = bpf2a64[FP_BOTTOM];
    568
    569	/* We're done with BPF stack */
    570	emit(A64_ADD_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);
    571
    572	/* Restore x27 and x28 */
    573	emit(A64_POP(fpb, A64_R(28), A64_SP), ctx);
      574	/* Restore fp (x25) and x26 */
    575	emit(A64_POP(fp, A64_R(26), A64_SP), ctx);
    576
      577	/* Restore callee-saved registers */
    578	emit(A64_POP(r8, r9, A64_SP), ctx);
    579	emit(A64_POP(r6, r7, A64_SP), ctx);
    580
    581	/* Restore FP/LR registers */
    582	emit(A64_POP(A64_FP, A64_LR, A64_SP), ctx);
    583
    584	/* Set return value */
    585	emit(A64_MOV(1, A64_R(0), r0), ctx);
    586
    587	/* Authenticate lr */
    588	if (IS_ENABLED(CONFIG_ARM64_PTR_AUTH_KERNEL))
    589		emit(A64_AUTIASP, ctx);
    590
    591	emit(A64_RET(A64_LR), ctx);
    592}
    593
    594#define BPF_FIXUP_OFFSET_MASK	GENMASK(26, 0)
    595#define BPF_FIXUP_REG_MASK	GENMASK(31, 27)
    596
    597bool ex_handler_bpf(const struct exception_table_entry *ex,
    598		    struct pt_regs *regs)
    599{
    600	off_t offset = FIELD_GET(BPF_FIXUP_OFFSET_MASK, ex->fixup);
    601	int dst_reg = FIELD_GET(BPF_FIXUP_REG_MASK, ex->fixup);
    602
    603	regs->regs[dst_reg] = 0;
    604	regs->pc = (unsigned long)&ex->fixup - offset;
    605	return true;
    606}
    607
    608/* For accesses to BTF pointers, add an entry to the exception table */
    609static int add_exception_handler(const struct bpf_insn *insn,
    610				 struct jit_ctx *ctx,
    611				 int dst_reg)
    612{
    613	off_t offset;
    614	unsigned long pc;
    615	struct exception_table_entry *ex;
    616
    617	if (!ctx->image)
    618		/* First pass */
    619		return 0;
    620
    621	if (BPF_MODE(insn->code) != BPF_PROBE_MEM)
    622		return 0;
    623
    624	if (!ctx->prog->aux->extable ||
    625	    WARN_ON_ONCE(ctx->exentry_idx >= ctx->prog->aux->num_exentries))
    626		return -EINVAL;
    627
    628	ex = &ctx->prog->aux->extable[ctx->exentry_idx];
    629	pc = (unsigned long)&ctx->image[ctx->idx - 1];
    630
    631	offset = pc - (long)&ex->insn;
    632	if (WARN_ON_ONCE(offset >= 0 || offset < INT_MIN))
    633		return -ERANGE;
    634	ex->insn = offset;
    635
    636	/*
    637	 * Since the extable follows the program, the fixup offset is always
    638	 * negative and limited to BPF_JIT_REGION_SIZE. Store a positive value
    639	 * to keep things simple, and put the destination register in the upper
    640	 * bits. We don't need to worry about buildtime or runtime sort
    641	 * modifying the upper bits because the table is already sorted, and
    642	 * isn't part of the main exception table.
    643	 */
    644	offset = (long)&ex->fixup - (pc + AARCH64_INSN_SIZE);
    645	if (!FIELD_FIT(BPF_FIXUP_OFFSET_MASK, offset))
    646		return -ERANGE;
    647
    648	ex->fixup = FIELD_PREP(BPF_FIXUP_OFFSET_MASK, offset) |
    649		    FIELD_PREP(BPF_FIXUP_REG_MASK, dst_reg);
    650
    651	ex->type = EX_TYPE_BPF;
    652
    653	ctx->exentry_idx++;
    654	return 0;
    655}
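        /*
         * Sanity check on the encoding above: ex_handler_bpf() computes
         * regs->pc = &ex->fixup - offset, and offset was stored here as
         * &ex->fixup - (pc + AARCH64_INSN_SIZE), so on a fault the destination
         * register is zeroed and execution resumes at pc + 4, i.e. at the
         * instruction immediately after the faulting load.
         */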
    656
    657/* JITs an eBPF instruction.
    658 * Returns:
    659 * 0  - successfully JITed an 8-byte eBPF instruction.
    660 * >0 - successfully JITed a 16-byte eBPF instruction.
    661 * <0 - failed to JIT.
    662 */
    663static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx,
    664		      bool extra_pass)
    665{
    666	const u8 code = insn->code;
    667	const u8 dst = bpf2a64[insn->dst_reg];
    668	const u8 src = bpf2a64[insn->src_reg];
    669	const u8 tmp = bpf2a64[TMP_REG_1];
    670	const u8 tmp2 = bpf2a64[TMP_REG_2];
    671	const u8 fp = bpf2a64[BPF_REG_FP];
    672	const u8 fpb = bpf2a64[FP_BOTTOM];
    673	const s16 off = insn->off;
    674	const s32 imm = insn->imm;
    675	const int i = insn - ctx->prog->insnsi;
    676	const bool is64 = BPF_CLASS(code) == BPF_ALU64 ||
    677			  BPF_CLASS(code) == BPF_JMP;
    678	u8 jmp_cond;
    679	s32 jmp_offset;
    680	u32 a64_insn;
    681	u8 src_adj;
    682	u8 dst_adj;
    683	int off_adj;
    684	int ret;
    685
    686	switch (code) {
    687	/* dst = src */
    688	case BPF_ALU | BPF_MOV | BPF_X:
    689	case BPF_ALU64 | BPF_MOV | BPF_X:
    690		emit(A64_MOV(is64, dst, src), ctx);
    691		break;
    692	/* dst = dst OP src */
    693	case BPF_ALU | BPF_ADD | BPF_X:
    694	case BPF_ALU64 | BPF_ADD | BPF_X:
    695		emit(A64_ADD(is64, dst, dst, src), ctx);
    696		break;
    697	case BPF_ALU | BPF_SUB | BPF_X:
    698	case BPF_ALU64 | BPF_SUB | BPF_X:
    699		emit(A64_SUB(is64, dst, dst, src), ctx);
    700		break;
    701	case BPF_ALU | BPF_AND | BPF_X:
    702	case BPF_ALU64 | BPF_AND | BPF_X:
    703		emit(A64_AND(is64, dst, dst, src), ctx);
    704		break;
    705	case BPF_ALU | BPF_OR | BPF_X:
    706	case BPF_ALU64 | BPF_OR | BPF_X:
    707		emit(A64_ORR(is64, dst, dst, src), ctx);
    708		break;
    709	case BPF_ALU | BPF_XOR | BPF_X:
    710	case BPF_ALU64 | BPF_XOR | BPF_X:
    711		emit(A64_EOR(is64, dst, dst, src), ctx);
    712		break;
    713	case BPF_ALU | BPF_MUL | BPF_X:
    714	case BPF_ALU64 | BPF_MUL | BPF_X:
    715		emit(A64_MUL(is64, dst, dst, src), ctx);
    716		break;
    717	case BPF_ALU | BPF_DIV | BPF_X:
    718	case BPF_ALU64 | BPF_DIV | BPF_X:
    719		emit(A64_UDIV(is64, dst, dst, src), ctx);
    720		break;
    721	case BPF_ALU | BPF_MOD | BPF_X:
    722	case BPF_ALU64 | BPF_MOD | BPF_X:
    723		emit(A64_UDIV(is64, tmp, dst, src), ctx);
    724		emit(A64_MSUB(is64, dst, dst, tmp, src), ctx);
    725		break;
    726	case BPF_ALU | BPF_LSH | BPF_X:
    727	case BPF_ALU64 | BPF_LSH | BPF_X:
    728		emit(A64_LSLV(is64, dst, dst, src), ctx);
    729		break;
    730	case BPF_ALU | BPF_RSH | BPF_X:
    731	case BPF_ALU64 | BPF_RSH | BPF_X:
    732		emit(A64_LSRV(is64, dst, dst, src), ctx);
    733		break;
    734	case BPF_ALU | BPF_ARSH | BPF_X:
    735	case BPF_ALU64 | BPF_ARSH | BPF_X:
    736		emit(A64_ASRV(is64, dst, dst, src), ctx);
    737		break;
    738	/* dst = -dst */
    739	case BPF_ALU | BPF_NEG:
    740	case BPF_ALU64 | BPF_NEG:
    741		emit(A64_NEG(is64, dst, dst), ctx);
    742		break;
    743	/* dst = BSWAP##imm(dst) */
    744	case BPF_ALU | BPF_END | BPF_FROM_LE:
    745	case BPF_ALU | BPF_END | BPF_FROM_BE:
    746#ifdef CONFIG_CPU_BIG_ENDIAN
    747		if (BPF_SRC(code) == BPF_FROM_BE)
    748			goto emit_bswap_uxt;
    749#else /* !CONFIG_CPU_BIG_ENDIAN */
    750		if (BPF_SRC(code) == BPF_FROM_LE)
    751			goto emit_bswap_uxt;
    752#endif
    753		switch (imm) {
    754		case 16:
    755			emit(A64_REV16(is64, dst, dst), ctx);
    756			/* zero-extend 16 bits into 64 bits */
    757			emit(A64_UXTH(is64, dst, dst), ctx);
    758			break;
    759		case 32:
    760			emit(A64_REV32(is64, dst, dst), ctx);
    761			/* upper 32 bits already cleared */
    762			break;
    763		case 64:
    764			emit(A64_REV64(dst, dst), ctx);
    765			break;
    766		}
    767		break;
    768emit_bswap_uxt:
    769		switch (imm) {
    770		case 16:
    771			/* zero-extend 16 bits into 64 bits */
    772			emit(A64_UXTH(is64, dst, dst), ctx);
    773			break;
    774		case 32:
    775			/* zero-extend 32 bits into 64 bits */
    776			emit(A64_UXTW(is64, dst, dst), ctx);
    777			break;
    778		case 64:
    779			/* nop */
    780			break;
    781		}
    782		break;
    783	/* dst = imm */
    784	case BPF_ALU | BPF_MOV | BPF_K:
    785	case BPF_ALU64 | BPF_MOV | BPF_K:
    786		emit_a64_mov_i(is64, dst, imm, ctx);
    787		break;
    788	/* dst = dst OP imm */
    789	case BPF_ALU | BPF_ADD | BPF_K:
    790	case BPF_ALU64 | BPF_ADD | BPF_K:
    791		if (is_addsub_imm(imm)) {
    792			emit(A64_ADD_I(is64, dst, dst, imm), ctx);
    793		} else if (is_addsub_imm(-imm)) {
    794			emit(A64_SUB_I(is64, dst, dst, -imm), ctx);
    795		} else {
    796			emit_a64_mov_i(is64, tmp, imm, ctx);
    797			emit(A64_ADD(is64, dst, dst, tmp), ctx);
    798		}
    799		break;
    800	case BPF_ALU | BPF_SUB | BPF_K:
    801	case BPF_ALU64 | BPF_SUB | BPF_K:
    802		if (is_addsub_imm(imm)) {
    803			emit(A64_SUB_I(is64, dst, dst, imm), ctx);
    804		} else if (is_addsub_imm(-imm)) {
    805			emit(A64_ADD_I(is64, dst, dst, -imm), ctx);
    806		} else {
    807			emit_a64_mov_i(is64, tmp, imm, ctx);
    808			emit(A64_SUB(is64, dst, dst, tmp), ctx);
    809		}
    810		break;
    811	case BPF_ALU | BPF_AND | BPF_K:
    812	case BPF_ALU64 | BPF_AND | BPF_K:
    813		a64_insn = A64_AND_I(is64, dst, dst, imm);
    814		if (a64_insn != AARCH64_BREAK_FAULT) {
    815			emit(a64_insn, ctx);
    816		} else {
    817			emit_a64_mov_i(is64, tmp, imm, ctx);
    818			emit(A64_AND(is64, dst, dst, tmp), ctx);
    819		}
    820		break;
    821	case BPF_ALU | BPF_OR | BPF_K:
    822	case BPF_ALU64 | BPF_OR | BPF_K:
    823		a64_insn = A64_ORR_I(is64, dst, dst, imm);
    824		if (a64_insn != AARCH64_BREAK_FAULT) {
    825			emit(a64_insn, ctx);
    826		} else {
    827			emit_a64_mov_i(is64, tmp, imm, ctx);
    828			emit(A64_ORR(is64, dst, dst, tmp), ctx);
    829		}
    830		break;
    831	case BPF_ALU | BPF_XOR | BPF_K:
    832	case BPF_ALU64 | BPF_XOR | BPF_K:
    833		a64_insn = A64_EOR_I(is64, dst, dst, imm);
    834		if (a64_insn != AARCH64_BREAK_FAULT) {
    835			emit(a64_insn, ctx);
    836		} else {
    837			emit_a64_mov_i(is64, tmp, imm, ctx);
    838			emit(A64_EOR(is64, dst, dst, tmp), ctx);
    839		}
    840		break;
    841	case BPF_ALU | BPF_MUL | BPF_K:
    842	case BPF_ALU64 | BPF_MUL | BPF_K:
    843		emit_a64_mov_i(is64, tmp, imm, ctx);
    844		emit(A64_MUL(is64, dst, dst, tmp), ctx);
    845		break;
    846	case BPF_ALU | BPF_DIV | BPF_K:
    847	case BPF_ALU64 | BPF_DIV | BPF_K:
    848		emit_a64_mov_i(is64, tmp, imm, ctx);
    849		emit(A64_UDIV(is64, dst, dst, tmp), ctx);
    850		break;
    851	case BPF_ALU | BPF_MOD | BPF_K:
    852	case BPF_ALU64 | BPF_MOD | BPF_K:
    853		emit_a64_mov_i(is64, tmp2, imm, ctx);
    854		emit(A64_UDIV(is64, tmp, dst, tmp2), ctx);
    855		emit(A64_MSUB(is64, dst, dst, tmp, tmp2), ctx);
    856		break;
    857	case BPF_ALU | BPF_LSH | BPF_K:
    858	case BPF_ALU64 | BPF_LSH | BPF_K:
    859		emit(A64_LSL(is64, dst, dst, imm), ctx);
    860		break;
    861	case BPF_ALU | BPF_RSH | BPF_K:
    862	case BPF_ALU64 | BPF_RSH | BPF_K:
    863		emit(A64_LSR(is64, dst, dst, imm), ctx);
    864		break;
    865	case BPF_ALU | BPF_ARSH | BPF_K:
    866	case BPF_ALU64 | BPF_ARSH | BPF_K:
    867		emit(A64_ASR(is64, dst, dst, imm), ctx);
    868		break;
    869
    870	/* JUMP off */
    871	case BPF_JMP | BPF_JA:
    872		jmp_offset = bpf2a64_offset(i, off, ctx);
    873		check_imm26(jmp_offset);
    874		emit(A64_B(jmp_offset), ctx);
    875		break;
    876	/* IF (dst COND src) JUMP off */
    877	case BPF_JMP | BPF_JEQ | BPF_X:
    878	case BPF_JMP | BPF_JGT | BPF_X:
    879	case BPF_JMP | BPF_JLT | BPF_X:
    880	case BPF_JMP | BPF_JGE | BPF_X:
    881	case BPF_JMP | BPF_JLE | BPF_X:
    882	case BPF_JMP | BPF_JNE | BPF_X:
    883	case BPF_JMP | BPF_JSGT | BPF_X:
    884	case BPF_JMP | BPF_JSLT | BPF_X:
    885	case BPF_JMP | BPF_JSGE | BPF_X:
    886	case BPF_JMP | BPF_JSLE | BPF_X:
    887	case BPF_JMP32 | BPF_JEQ | BPF_X:
    888	case BPF_JMP32 | BPF_JGT | BPF_X:
    889	case BPF_JMP32 | BPF_JLT | BPF_X:
    890	case BPF_JMP32 | BPF_JGE | BPF_X:
    891	case BPF_JMP32 | BPF_JLE | BPF_X:
    892	case BPF_JMP32 | BPF_JNE | BPF_X:
    893	case BPF_JMP32 | BPF_JSGT | BPF_X:
    894	case BPF_JMP32 | BPF_JSLT | BPF_X:
    895	case BPF_JMP32 | BPF_JSGE | BPF_X:
    896	case BPF_JMP32 | BPF_JSLE | BPF_X:
    897		emit(A64_CMP(is64, dst, src), ctx);
    898emit_cond_jmp:
    899		jmp_offset = bpf2a64_offset(i, off, ctx);
    900		check_imm19(jmp_offset);
    901		switch (BPF_OP(code)) {
    902		case BPF_JEQ:
    903			jmp_cond = A64_COND_EQ;
    904			break;
    905		case BPF_JGT:
    906			jmp_cond = A64_COND_HI;
    907			break;
    908		case BPF_JLT:
    909			jmp_cond = A64_COND_CC;
    910			break;
    911		case BPF_JGE:
    912			jmp_cond = A64_COND_CS;
    913			break;
    914		case BPF_JLE:
    915			jmp_cond = A64_COND_LS;
    916			break;
    917		case BPF_JSET:
    918		case BPF_JNE:
    919			jmp_cond = A64_COND_NE;
    920			break;
    921		case BPF_JSGT:
    922			jmp_cond = A64_COND_GT;
    923			break;
    924		case BPF_JSLT:
    925			jmp_cond = A64_COND_LT;
    926			break;
    927		case BPF_JSGE:
    928			jmp_cond = A64_COND_GE;
    929			break;
    930		case BPF_JSLE:
    931			jmp_cond = A64_COND_LE;
    932			break;
    933		default:
    934			return -EFAULT;
    935		}
    936		emit(A64_B_(jmp_cond, jmp_offset), ctx);
    937		break;
    938	case BPF_JMP | BPF_JSET | BPF_X:
    939	case BPF_JMP32 | BPF_JSET | BPF_X:
    940		emit(A64_TST(is64, dst, src), ctx);
    941		goto emit_cond_jmp;
    942	/* IF (dst COND imm) JUMP off */
    943	case BPF_JMP | BPF_JEQ | BPF_K:
    944	case BPF_JMP | BPF_JGT | BPF_K:
    945	case BPF_JMP | BPF_JLT | BPF_K:
    946	case BPF_JMP | BPF_JGE | BPF_K:
    947	case BPF_JMP | BPF_JLE | BPF_K:
    948	case BPF_JMP | BPF_JNE | BPF_K:
    949	case BPF_JMP | BPF_JSGT | BPF_K:
    950	case BPF_JMP | BPF_JSLT | BPF_K:
    951	case BPF_JMP | BPF_JSGE | BPF_K:
    952	case BPF_JMP | BPF_JSLE | BPF_K:
    953	case BPF_JMP32 | BPF_JEQ | BPF_K:
    954	case BPF_JMP32 | BPF_JGT | BPF_K:
    955	case BPF_JMP32 | BPF_JLT | BPF_K:
    956	case BPF_JMP32 | BPF_JGE | BPF_K:
    957	case BPF_JMP32 | BPF_JLE | BPF_K:
    958	case BPF_JMP32 | BPF_JNE | BPF_K:
    959	case BPF_JMP32 | BPF_JSGT | BPF_K:
    960	case BPF_JMP32 | BPF_JSLT | BPF_K:
    961	case BPF_JMP32 | BPF_JSGE | BPF_K:
    962	case BPF_JMP32 | BPF_JSLE | BPF_K:
    963		if (is_addsub_imm(imm)) {
    964			emit(A64_CMP_I(is64, dst, imm), ctx);
    965		} else if (is_addsub_imm(-imm)) {
    966			emit(A64_CMN_I(is64, dst, -imm), ctx);
    967		} else {
    968			emit_a64_mov_i(is64, tmp, imm, ctx);
    969			emit(A64_CMP(is64, dst, tmp), ctx);
    970		}
    971		goto emit_cond_jmp;
    972	case BPF_JMP | BPF_JSET | BPF_K:
    973	case BPF_JMP32 | BPF_JSET | BPF_K:
    974		a64_insn = A64_TST_I(is64, dst, imm);
    975		if (a64_insn != AARCH64_BREAK_FAULT) {
    976			emit(a64_insn, ctx);
    977		} else {
    978			emit_a64_mov_i(is64, tmp, imm, ctx);
    979			emit(A64_TST(is64, dst, tmp), ctx);
    980		}
    981		goto emit_cond_jmp;
    982	/* function call */
    983	case BPF_JMP | BPF_CALL:
    984	{
    985		const u8 r0 = bpf2a64[BPF_REG_0];
    986		bool func_addr_fixed;
    987		u64 func_addr;
    988
    989		ret = bpf_jit_get_func_addr(ctx->prog, insn, extra_pass,
    990					    &func_addr, &func_addr_fixed);
    991		if (ret < 0)
    992			return ret;
    993		emit_addr_mov_i64(tmp, func_addr, ctx);
    994		emit(A64_BLR(tmp), ctx);
    995		emit(A64_MOV(1, r0, A64_R(0)), ctx);
    996		break;
    997	}
    998	/* tail call */
    999	case BPF_JMP | BPF_TAIL_CALL:
   1000		if (emit_bpf_tail_call(ctx))
   1001			return -EFAULT;
   1002		break;
   1003	/* function return */
   1004	case BPF_JMP | BPF_EXIT:
    1005		/* Optimization: when the last instruction is EXIT,
    1006		   simply fall through to the epilogue. */
   1007		if (i == ctx->prog->len - 1)
   1008			break;
   1009		jmp_offset = epilogue_offset(ctx);
   1010		check_imm26(jmp_offset);
   1011		emit(A64_B(jmp_offset), ctx);
   1012		break;
   1013
   1014	/* dst = imm64 */
   1015	case BPF_LD | BPF_IMM | BPF_DW:
   1016	{
   1017		const struct bpf_insn insn1 = insn[1];
   1018		u64 imm64;
   1019
   1020		imm64 = (u64)insn1.imm << 32 | (u32)imm;
   1021		if (bpf_pseudo_func(insn))
   1022			emit_addr_mov_i64(dst, imm64, ctx);
   1023		else
   1024			emit_a64_mov_i64(dst, imm64, ctx);
   1025
   1026		return 1;
   1027	}
   1028
   1029	/* LDX: dst = *(size *)(src + off) */
   1030	case BPF_LDX | BPF_MEM | BPF_W:
   1031	case BPF_LDX | BPF_MEM | BPF_H:
   1032	case BPF_LDX | BPF_MEM | BPF_B:
   1033	case BPF_LDX | BPF_MEM | BPF_DW:
   1034	case BPF_LDX | BPF_PROBE_MEM | BPF_DW:
   1035	case BPF_LDX | BPF_PROBE_MEM | BPF_W:
   1036	case BPF_LDX | BPF_PROBE_MEM | BPF_H:
   1037	case BPF_LDX | BPF_PROBE_MEM | BPF_B:
   1038		if (ctx->fpb_offset > 0 && src == fp) {
   1039			src_adj = fpb;
   1040			off_adj = off + ctx->fpb_offset;
   1041		} else {
   1042			src_adj = src;
   1043			off_adj = off;
   1044		}
   1045		switch (BPF_SIZE(code)) {
   1046		case BPF_W:
   1047			if (is_lsi_offset(off_adj, 2)) {
   1048				emit(A64_LDR32I(dst, src_adj, off_adj), ctx);
   1049			} else {
   1050				emit_a64_mov_i(1, tmp, off, ctx);
   1051				emit(A64_LDR32(dst, src, tmp), ctx);
   1052			}
   1053			break;
   1054		case BPF_H:
   1055			if (is_lsi_offset(off_adj, 1)) {
   1056				emit(A64_LDRHI(dst, src_adj, off_adj), ctx);
   1057			} else {
   1058				emit_a64_mov_i(1, tmp, off, ctx);
   1059				emit(A64_LDRH(dst, src, tmp), ctx);
   1060			}
   1061			break;
   1062		case BPF_B:
   1063			if (is_lsi_offset(off_adj, 0)) {
   1064				emit(A64_LDRBI(dst, src_adj, off_adj), ctx);
   1065			} else {
   1066				emit_a64_mov_i(1, tmp, off, ctx);
   1067				emit(A64_LDRB(dst, src, tmp), ctx);
   1068			}
   1069			break;
   1070		case BPF_DW:
   1071			if (is_lsi_offset(off_adj, 3)) {
   1072				emit(A64_LDR64I(dst, src_adj, off_adj), ctx);
   1073			} else {
   1074				emit_a64_mov_i(1, tmp, off, ctx);
   1075				emit(A64_LDR64(dst, src, tmp), ctx);
   1076			}
   1077			break;
   1078		}
   1079
   1080		ret = add_exception_handler(insn, ctx, dst);
   1081		if (ret)
   1082			return ret;
   1083		break;
   1084
   1085	/* speculation barrier */
   1086	case BPF_ST | BPF_NOSPEC:
   1087		/*
   1088		 * Nothing required here.
   1089		 *
   1090		 * In case of arm64, we rely on the firmware mitigation of
   1091		 * Speculative Store Bypass as controlled via the ssbd kernel
   1092		 * parameter. Whenever the mitigation is enabled, it works
   1093		 * for all of the kernel code with no need to provide any
   1094		 * additional instructions.
   1095		 */
   1096		break;
   1097
   1098	/* ST: *(size *)(dst + off) = imm */
   1099	case BPF_ST | BPF_MEM | BPF_W:
   1100	case BPF_ST | BPF_MEM | BPF_H:
   1101	case BPF_ST | BPF_MEM | BPF_B:
   1102	case BPF_ST | BPF_MEM | BPF_DW:
   1103		if (ctx->fpb_offset > 0 && dst == fp) {
   1104			dst_adj = fpb;
   1105			off_adj = off + ctx->fpb_offset;
   1106		} else {
   1107			dst_adj = dst;
   1108			off_adj = off;
   1109		}
   1110		/* Load imm to a register then store it */
   1111		emit_a64_mov_i(1, tmp, imm, ctx);
   1112		switch (BPF_SIZE(code)) {
   1113		case BPF_W:
   1114			if (is_lsi_offset(off_adj, 2)) {
   1115				emit(A64_STR32I(tmp, dst_adj, off_adj), ctx);
   1116			} else {
   1117				emit_a64_mov_i(1, tmp2, off, ctx);
   1118				emit(A64_STR32(tmp, dst, tmp2), ctx);
   1119			}
   1120			break;
   1121		case BPF_H:
   1122			if (is_lsi_offset(off_adj, 1)) {
   1123				emit(A64_STRHI(tmp, dst_adj, off_adj), ctx);
   1124			} else {
   1125				emit_a64_mov_i(1, tmp2, off, ctx);
   1126				emit(A64_STRH(tmp, dst, tmp2), ctx);
   1127			}
   1128			break;
   1129		case BPF_B:
   1130			if (is_lsi_offset(off_adj, 0)) {
   1131				emit(A64_STRBI(tmp, dst_adj, off_adj), ctx);
   1132			} else {
   1133				emit_a64_mov_i(1, tmp2, off, ctx);
   1134				emit(A64_STRB(tmp, dst, tmp2), ctx);
   1135			}
   1136			break;
   1137		case BPF_DW:
   1138			if (is_lsi_offset(off_adj, 3)) {
   1139				emit(A64_STR64I(tmp, dst_adj, off_adj), ctx);
   1140			} else {
   1141				emit_a64_mov_i(1, tmp2, off, ctx);
   1142				emit(A64_STR64(tmp, dst, tmp2), ctx);
   1143			}
   1144			break;
   1145		}
   1146		break;
   1147
   1148	/* STX: *(size *)(dst + off) = src */
   1149	case BPF_STX | BPF_MEM | BPF_W:
   1150	case BPF_STX | BPF_MEM | BPF_H:
   1151	case BPF_STX | BPF_MEM | BPF_B:
   1152	case BPF_STX | BPF_MEM | BPF_DW:
   1153		if (ctx->fpb_offset > 0 && dst == fp) {
   1154			dst_adj = fpb;
   1155			off_adj = off + ctx->fpb_offset;
   1156		} else {
   1157			dst_adj = dst;
   1158			off_adj = off;
   1159		}
   1160		switch (BPF_SIZE(code)) {
   1161		case BPF_W:
   1162			if (is_lsi_offset(off_adj, 2)) {
   1163				emit(A64_STR32I(src, dst_adj, off_adj), ctx);
   1164			} else {
   1165				emit_a64_mov_i(1, tmp, off, ctx);
   1166				emit(A64_STR32(src, dst, tmp), ctx);
   1167			}
   1168			break;
   1169		case BPF_H:
   1170			if (is_lsi_offset(off_adj, 1)) {
   1171				emit(A64_STRHI(src, dst_adj, off_adj), ctx);
   1172			} else {
   1173				emit_a64_mov_i(1, tmp, off, ctx);
   1174				emit(A64_STRH(src, dst, tmp), ctx);
   1175			}
   1176			break;
   1177		case BPF_B:
   1178			if (is_lsi_offset(off_adj, 0)) {
   1179				emit(A64_STRBI(src, dst_adj, off_adj), ctx);
   1180			} else {
   1181				emit_a64_mov_i(1, tmp, off, ctx);
   1182				emit(A64_STRB(src, dst, tmp), ctx);
   1183			}
   1184			break;
   1185		case BPF_DW:
   1186			if (is_lsi_offset(off_adj, 3)) {
   1187				emit(A64_STR64I(src, dst_adj, off_adj), ctx);
   1188			} else {
   1189				emit_a64_mov_i(1, tmp, off, ctx);
   1190				emit(A64_STR64(src, dst, tmp), ctx);
   1191			}
   1192			break;
   1193		}
   1194		break;
   1195
   1196	case BPF_STX | BPF_ATOMIC | BPF_W:
   1197	case BPF_STX | BPF_ATOMIC | BPF_DW:
   1198		if (cpus_have_cap(ARM64_HAS_LSE_ATOMICS))
   1199			ret = emit_lse_atomic(insn, ctx);
   1200		else
   1201			ret = emit_ll_sc_atomic(insn, ctx);
   1202		if (ret)
   1203			return ret;
   1204		break;
   1205
   1206	default:
   1207		pr_err_once("unknown opcode %02x\n", code);
   1208		return -EINVAL;
   1209	}
   1210
   1211	return 0;
   1212}
   1213
   1214/*
   1215 * Return 0 if FP may change at runtime, otherwise find the minimum negative
    1216 * offset to FP, convert it to a positive number, and align it down to 8 bytes.
   1217 */
   1218static int find_fpb_offset(struct bpf_prog *prog)
   1219{
   1220	int i;
   1221	int offset = 0;
   1222
   1223	for (i = 0; i < prog->len; i++) {
   1224		const struct bpf_insn *insn = &prog->insnsi[i];
   1225		const u8 class = BPF_CLASS(insn->code);
   1226		const u8 mode = BPF_MODE(insn->code);
   1227		const u8 src = insn->src_reg;
   1228		const u8 dst = insn->dst_reg;
   1229		const s32 imm = insn->imm;
   1230		const s16 off = insn->off;
   1231
   1232		switch (class) {
   1233		case BPF_STX:
   1234		case BPF_ST:
   1235			/* fp holds atomic operation result */
   1236			if (class == BPF_STX && mode == BPF_ATOMIC &&
   1237			    ((imm == BPF_XCHG ||
   1238			      imm == (BPF_FETCH | BPF_ADD) ||
   1239			      imm == (BPF_FETCH | BPF_AND) ||
   1240			      imm == (BPF_FETCH | BPF_XOR) ||
   1241			      imm == (BPF_FETCH | BPF_OR)) &&
   1242			     src == BPF_REG_FP))
   1243				return 0;
   1244
   1245			if (mode == BPF_MEM && dst == BPF_REG_FP &&
   1246			    off < offset)
   1247				offset = insn->off;
   1248			break;
   1249
   1250		case BPF_JMP32:
   1251		case BPF_JMP:
   1252			break;
   1253
   1254		case BPF_LDX:
   1255		case BPF_LD:
   1256			/* fp holds load result */
   1257			if (dst == BPF_REG_FP)
   1258				return 0;
   1259
   1260			if (class == BPF_LDX && mode == BPF_MEM &&
   1261			    src == BPF_REG_FP && off < offset)
   1262				offset = off;
   1263			break;
   1264
   1265		case BPF_ALU:
   1266		case BPF_ALU64:
   1267		default:
   1268			/* fp holds ALU result */
   1269			if (dst == BPF_REG_FP)
   1270				return 0;
   1271		}
   1272	}
   1273
   1274	if (offset < 0) {
   1275		/*
    1276		 * 'offset' can safely be converted to a positive 'int', since
    1277		 * insn->off is 's16'
   1278		 */
   1279		offset = -offset;
   1280		/* align down to 8 bytes */
   1281		offset = ALIGN_DOWN(offset, 8);
   1282	}
   1283
   1284	return offset;
   1285}
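        /*
         * Example: if the most negative FP-relative access in a program is
         * *(u64 *)(fp - 512), this returns 512 and build_prologue() sets
         * fpb = fp - 512. A load from fp - 8 is then emitted as an
         * unsigned-offset LDR at fpb + 504, which fits the imm12 form, whereas
         * any negative offset from fp itself would need the tmp-register
         * fallback.
         */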
   1286
   1287static int build_body(struct jit_ctx *ctx, bool extra_pass)
   1288{
   1289	const struct bpf_prog *prog = ctx->prog;
   1290	int i;
   1291
   1292	/*
    1293	 * - offset[0] - offset of the end of the prologue,
    1294	 *   start of the 1st instruction.
    1295	 * - offset[1] - offset of the end of the 1st instruction,
    1296	 *   start of the 2nd instruction.
    1297	 * [....]
    1298	 * - offset[3] - offset of the end of the 3rd instruction,
    1299	 *   start of the 4th instruction.
   1300	 */
   1301	for (i = 0; i < prog->len; i++) {
   1302		const struct bpf_insn *insn = &prog->insnsi[i];
   1303		int ret;
   1304
   1305		if (ctx->image == NULL)
   1306			ctx->offset[i] = ctx->idx;
   1307		ret = build_insn(insn, ctx, extra_pass);
   1308		if (ret > 0) {
   1309			i++;
   1310			if (ctx->image == NULL)
   1311				ctx->offset[i] = ctx->idx;
   1312			continue;
   1313		}
   1314		if (ret)
   1315			return ret;
   1316	}
   1317	/*
   1318	 * offset is allocated with prog->len + 1 so fill in
   1319	 * the last element with the offset after the last
   1320	 * instruction (end of program)
   1321	 */
   1322	if (ctx->image == NULL)
   1323		ctx->offset[i] = ctx->idx;
   1324
   1325	return 0;
   1326}
   1327
   1328static int validate_code(struct jit_ctx *ctx)
   1329{
   1330	int i;
   1331
   1332	for (i = 0; i < ctx->idx; i++) {
   1333		u32 a64_insn = le32_to_cpu(ctx->image[i]);
   1334
   1335		if (a64_insn == AARCH64_BREAK_FAULT)
   1336			return -1;
   1337	}
   1338
   1339	if (WARN_ON_ONCE(ctx->exentry_idx != ctx->prog->aux->num_exentries))
   1340		return -1;
   1341
   1342	return 0;
   1343}
   1344
   1345static inline void bpf_flush_icache(void *start, void *end)
   1346{
   1347	flush_icache_range((unsigned long)start, (unsigned long)end);
   1348}
   1349
   1350struct arm64_jit_data {
   1351	struct bpf_binary_header *header;
   1352	u8 *image;
   1353	struct jit_ctx ctx;
   1354};
   1355
   1356struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
   1357{
   1358	int image_size, prog_size, extable_size;
   1359	struct bpf_prog *tmp, *orig_prog = prog;
   1360	struct bpf_binary_header *header;
   1361	struct arm64_jit_data *jit_data;
   1362	bool was_classic = bpf_prog_was_classic(prog);
   1363	bool tmp_blinded = false;
   1364	bool extra_pass = false;
   1365	struct jit_ctx ctx;
   1366	u8 *image_ptr;
   1367
   1368	if (!prog->jit_requested)
   1369		return orig_prog;
   1370
   1371	tmp = bpf_jit_blind_constants(prog);
   1372	/* If blinding was requested and we failed during blinding,
   1373	 * we must fall back to the interpreter.
   1374	 */
   1375	if (IS_ERR(tmp))
   1376		return orig_prog;
   1377	if (tmp != prog) {
   1378		tmp_blinded = true;
   1379		prog = tmp;
   1380	}
   1381
   1382	jit_data = prog->aux->jit_data;
   1383	if (!jit_data) {
   1384		jit_data = kzalloc(sizeof(*jit_data), GFP_KERNEL);
   1385		if (!jit_data) {
   1386			prog = orig_prog;
   1387			goto out;
   1388		}
   1389		prog->aux->jit_data = jit_data;
   1390	}
   1391	if (jit_data->ctx.offset) {
   1392		ctx = jit_data->ctx;
   1393		image_ptr = jit_data->image;
   1394		header = jit_data->header;
   1395		extra_pass = true;
   1396		prog_size = sizeof(u32) * ctx.idx;
   1397		goto skip_init_ctx;
   1398	}
   1399	memset(&ctx, 0, sizeof(ctx));
   1400	ctx.prog = prog;
   1401
   1402	ctx.offset = kcalloc(prog->len + 1, sizeof(int), GFP_KERNEL);
   1403	if (ctx.offset == NULL) {
   1404		prog = orig_prog;
   1405		goto out_off;
   1406	}
   1407
   1408	ctx.fpb_offset = find_fpb_offset(prog);
   1409
   1410	/*
   1411	 * 1. Initial fake pass to compute ctx->idx and ctx->offset.
   1412	 *
   1413	 * BPF line info needs ctx->offset[i] to be the offset of
    1414	 * instruction[i] in the jited image, so build the prologue first.
   1415	 */
   1416	if (build_prologue(&ctx, was_classic)) {
   1417		prog = orig_prog;
   1418		goto out_off;
   1419	}
   1420
   1421	if (build_body(&ctx, extra_pass)) {
   1422		prog = orig_prog;
   1423		goto out_off;
   1424	}
   1425
   1426	ctx.epilogue_offset = ctx.idx;
   1427	build_epilogue(&ctx);
   1428
   1429	extable_size = prog->aux->num_exentries *
   1430		sizeof(struct exception_table_entry);
   1431
   1432	/* Now we know the actual image size. */
   1433	prog_size = sizeof(u32) * ctx.idx;
   1434	image_size = prog_size + extable_size;
   1435	header = bpf_jit_binary_alloc(image_size, &image_ptr,
   1436				      sizeof(u32), jit_fill_hole);
   1437	if (header == NULL) {
   1438		prog = orig_prog;
   1439		goto out_off;
   1440	}
   1441
   1442	/* 2. Now, the actual pass. */
   1443
   1444	ctx.image = (__le32 *)image_ptr;
   1445	if (extable_size)
   1446		prog->aux->extable = (void *)image_ptr + prog_size;
   1447skip_init_ctx:
   1448	ctx.idx = 0;
   1449	ctx.exentry_idx = 0;
   1450
   1451	build_prologue(&ctx, was_classic);
   1452
   1453	if (build_body(&ctx, extra_pass)) {
   1454		bpf_jit_binary_free(header);
   1455		prog = orig_prog;
   1456		goto out_off;
   1457	}
   1458
   1459	build_epilogue(&ctx);
   1460
   1461	/* 3. Extra pass to validate JITed code. */
   1462	if (validate_code(&ctx)) {
   1463		bpf_jit_binary_free(header);
   1464		prog = orig_prog;
   1465		goto out_off;
   1466	}
   1467
   1468	/* And we're done. */
   1469	if (bpf_jit_enable > 1)
   1470		bpf_jit_dump(prog->len, prog_size, 2, ctx.image);
   1471
   1472	bpf_flush_icache(header, ctx.image + ctx.idx);
   1473
   1474	if (!prog->is_func || extra_pass) {
   1475		if (extra_pass && ctx.idx != jit_data->ctx.idx) {
   1476			pr_err_once("multi-func JIT bug %d != %d\n",
   1477				    ctx.idx, jit_data->ctx.idx);
   1478			bpf_jit_binary_free(header);
   1479			prog->bpf_func = NULL;
   1480			prog->jited = 0;
   1481			prog->jited_len = 0;
   1482			goto out_off;
   1483		}
   1484		bpf_jit_binary_lock_ro(header);
   1485	} else {
   1486		jit_data->ctx = ctx;
   1487		jit_data->image = image_ptr;
   1488		jit_data->header = header;
   1489	}
   1490	prog->bpf_func = (void *)ctx.image;
   1491	prog->jited = 1;
   1492	prog->jited_len = prog_size;
   1493
   1494	if (!prog->is_func || extra_pass) {
   1495		int i;
   1496
    1497		/* offset[prog->len] is the size of the program */
   1498		for (i = 0; i <= prog->len; i++)
   1499			ctx.offset[i] *= AARCH64_INSN_SIZE;
   1500		bpf_prog_fill_jited_linfo(prog, ctx.offset + 1);
   1501out_off:
   1502		kfree(ctx.offset);
   1503		kfree(jit_data);
   1504		prog->aux->jit_data = NULL;
   1505	}
   1506out:
   1507	if (tmp_blinded)
   1508		bpf_jit_prog_release_other(prog, prog == orig_prog ?
   1509					   tmp : orig_prog);
   1510	return prog;
   1511}
   1512
   1513bool bpf_jit_supports_kfunc_call(void)
   1514{
   1515	return true;
   1516}
   1517
   1518u64 bpf_jit_alloc_exec_limit(void)
   1519{
   1520	return VMALLOC_END - VMALLOC_START;
   1521}
   1522
   1523void *bpf_jit_alloc_exec(unsigned long size)
   1524{
   1525	/* Memory is intended to be executable, reset the pointer tag. */
   1526	return kasan_reset_tag(vmalloc(size));
   1527}
   1528
   1529void bpf_jit_free_exec(void *addr)
   1530{
   1531	return vfree(addr);
   1532}