cachepc-linux

Fork of AMDESE/linux with modifications for the CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux

bpf_jit_comp.c (68537B)


      1// SPDX-License-Identifier: GPL-2.0-only
      2/*
      3 * BPF JIT compiler
      4 *
      5 * Copyright (C) 2011-2013 Eric Dumazet (eric.dumazet@gmail.com)
      6 * Copyright (c) 2011-2014 PLUMgrid, http://plumgrid.com
      7 */
      8#include <linux/netdevice.h>
      9#include <linux/filter.h>
     10#include <linux/if_vlan.h>
     11#include <linux/bpf.h>
     12#include <linux/memory.h>
     13#include <linux/sort.h>
     14#include <asm/extable.h>
     15#include <asm/set_memory.h>
     16#include <asm/nospec-branch.h>
     17#include <asm/text-patching.h>
     18
     19static u8 *emit_code(u8 *ptr, u32 bytes, unsigned int len)
     20{
     21	if (len == 1)
     22		*ptr = bytes;
     23	else if (len == 2)
     24		*(u16 *)ptr = bytes;
     25	else {
     26		*(u32 *)ptr = bytes;
     27		barrier();
     28	}
     29	return ptr + len;
     30}
     31
     32#define EMIT(bytes, len) \
     33	do { prog = emit_code(prog, bytes, len); } while (0)
     34
     35#define EMIT1(b1)		EMIT(b1, 1)
     36#define EMIT2(b1, b2)		EMIT((b1) + ((b2) << 8), 2)
     37#define EMIT3(b1, b2, b3)	EMIT((b1) + ((b2) << 8) + ((b3) << 16), 3)
     38#define EMIT4(b1, b2, b3, b4)   EMIT((b1) + ((b2) << 8) + ((b3) << 16) + ((b4) << 24), 4)
     39
     40#define EMIT1_off32(b1, off) \
     41	do { EMIT1(b1); EMIT(off, 4); } while (0)
     42#define EMIT2_off32(b1, b2, off) \
     43	do { EMIT2(b1, b2); EMIT(off, 4); } while (0)
     44#define EMIT3_off32(b1, b2, b3, off) \
     45	do { EMIT3(b1, b2, b3); EMIT(off, 4); } while (0)
     46#define EMIT4_off32(b1, b2, b3, b4, off) \
     47	do { EMIT4(b1, b2, b3, b4); EMIT(off, 4); } while (0)
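/*
 * Illustration of how the EMIT macros pack bytes: EMIT3(0x48, 0x89, 0xE5)
 * builds 0x48 + (0x89 << 8) + (0xE5 << 16) and emit_code() stores that u32
 * on a little-endian machine, so the bytes land in memory in the order
 * 48 89 E5 -- the encoding of 'mov rbp, rsp' used in the prologue.  For
 * len == 3 the store still writes four bytes, but the cursor only advances
 * by three, so the surplus byte is overwritten by later output.
 */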
     48
     49#ifdef CONFIG_X86_KERNEL_IBT
     50#define EMIT_ENDBR()	EMIT(gen_endbr(), 4)
     51#else
     52#define EMIT_ENDBR()
     53#endif
     54
     55static bool is_imm8(int value)
     56{
     57	return value <= 127 && value >= -128;
     58}
     59
     60static bool is_simm32(s64 value)
     61{
     62	return value == (s64)(s32)value;
     63}
     64
     65static bool is_uimm32(u64 value)
     66{
     67	return value == (u64)(u32)value;
     68}
     69
     70/* mov dst, src */
     71#define EMIT_mov(DST, SRC)								 \
     72	do {										 \
     73		if (DST != SRC)								 \
     74			EMIT3(add_2mod(0x48, DST, SRC), 0x89, add_2reg(0xC0, DST, SRC)); \
     75	} while (0)
     76
     77static int bpf_size_to_x86_bytes(int bpf_size)
     78{
     79	if (bpf_size == BPF_W)
     80		return 4;
     81	else if (bpf_size == BPF_H)
     82		return 2;
     83	else if (bpf_size == BPF_B)
     84		return 1;
     85	else if (bpf_size == BPF_DW)
     86		return 4; /* imm32 */
     87	else
     88		return 0;
     89}
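/*
 * Returning 4 for BPF_DW is deliberate: this helper sizes the immediate of a
 * BPF_ST store (see the BPF_ST | BPF_MEM cases in do_jit()), and x86-64
 * 'mov r/m64, imm32' only takes a sign-extended 32-bit immediate, so even
 * the DW case emits four immediate bytes.
 */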
     90
     91/*
     92 * List of x86 cond jumps opcodes (. + s8)
     93 * Add 0x10 (and an extra 0x0f) to generate far jumps (. + s32)
     94 */
     95#define X86_JB  0x72
     96#define X86_JAE 0x73
     97#define X86_JE  0x74
     98#define X86_JNE 0x75
     99#define X86_JBE 0x76
    100#define X86_JA  0x77
    101#define X86_JL  0x7C
    102#define X86_JGE 0x7D
    103#define X86_JLE 0x7E
    104#define X86_JG  0x7F
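/*
 * Example of the far-jump rule above: the short form of 'je' is 0x74 with an
 * 8-bit displacement, the far form is 0x0F 0x84 with a 32-bit displacement
 * (0x74 + 0x10 == 0x84).  The cond-jump emitter below uses exactly this:
 * EMIT2_off32(0x0F, jmp_cond + 0x10, jmp_offset).
 */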
    105
    106/* Pick a register outside of BPF range for JIT internal work */
    107#define AUX_REG (MAX_BPF_JIT_REG + 1)
    108#define X86_REG_R9 (MAX_BPF_JIT_REG + 2)
    109
    110/*
    111 * The following table maps BPF registers to x86-64 registers.
    112 *
    113 * x86-64 register R12 is unused, since if used as base address
    114 * register in load/store instructions, it always needs an
    115 * extra byte of encoding and is callee saved.
    116 *
    117 * x86-64 register R9 is not used by BPF programs, but can be used by BPF
    118 * trampoline. x86-64 register R10 is used for blinding (if enabled).
    119 */
    120static const int reg2hex[] = {
    121	[BPF_REG_0] = 0,  /* RAX */
    122	[BPF_REG_1] = 7,  /* RDI */
    123	[BPF_REG_2] = 6,  /* RSI */
    124	[BPF_REG_3] = 2,  /* RDX */
    125	[BPF_REG_4] = 1,  /* RCX */
    126	[BPF_REG_5] = 0,  /* R8  */
    127	[BPF_REG_6] = 3,  /* RBX callee saved */
    128	[BPF_REG_7] = 5,  /* R13 callee saved */
    129	[BPF_REG_8] = 6,  /* R14 callee saved */
    130	[BPF_REG_9] = 7,  /* R15 callee saved */
    131	[BPF_REG_FP] = 5, /* RBP readonly */
    132	[BPF_REG_AX] = 2, /* R10 temp register */
    133	[AUX_REG] = 3,    /* R11 temp register */
    134	[X86_REG_R9] = 1, /* R9 register, 6th function argument */
    135};
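/*
 * Note on the mapping: BPF argument registers R1..R5 land in RDI, RSI, RDX,
 * RCX and R8, the first five integer argument registers of the x86-64 SysV
 * calling convention, so a helper call can be emitted as a plain x86 call
 * with the arguments already in place.  reg2hex[] holds only the low three
 * bits of the hardware register number; the fourth bit (r8..r15) is supplied
 * through the REX prefix by is_ereg()/add_1mod()/add_2mod() below.
 */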
    136
    137static const int reg2pt_regs[] = {
    138	[BPF_REG_0] = offsetof(struct pt_regs, ax),
    139	[BPF_REG_1] = offsetof(struct pt_regs, di),
    140	[BPF_REG_2] = offsetof(struct pt_regs, si),
    141	[BPF_REG_3] = offsetof(struct pt_regs, dx),
    142	[BPF_REG_4] = offsetof(struct pt_regs, cx),
    143	[BPF_REG_5] = offsetof(struct pt_regs, r8),
    144	[BPF_REG_6] = offsetof(struct pt_regs, bx),
    145	[BPF_REG_7] = offsetof(struct pt_regs, r13),
    146	[BPF_REG_8] = offsetof(struct pt_regs, r14),
    147	[BPF_REG_9] = offsetof(struct pt_regs, r15),
    148};
    149
    150/*
    151 * is_ereg() == true if BPF register 'reg' maps to x86-64 r8..r15
    152 * which need extra byte of encoding.
    153 * rax,rcx,...,rbp have simpler encoding
    154 */
    155static bool is_ereg(u32 reg)
    156{
    157	return (1 << reg) & (BIT(BPF_REG_5) |
    158			     BIT(AUX_REG) |
    159			     BIT(BPF_REG_7) |
    160			     BIT(BPF_REG_8) |
    161			     BIT(BPF_REG_9) |
    162			     BIT(X86_REG_R9) |
    163			     BIT(BPF_REG_AX));
    164}
    165
    166/*
    167 * is_ereg_8l() == true if BPF register 'reg' is mapped to access x86-64
    168 * lower 8-bit registers dil,sil,bpl,spl,r8b..r15b, which need extra byte
    169 * of encoding. al,cl,dl,bl have simpler encoding.
    170 */
    171static bool is_ereg_8l(u32 reg)
    172{
    173	return is_ereg(reg) ||
    174	    (1 << reg) & (BIT(BPF_REG_1) |
    175			  BIT(BPF_REG_2) |
    176			  BIT(BPF_REG_FP));
    177}
    178
    179static bool is_axreg(u32 reg)
    180{
    181	return reg == BPF_REG_0;
    182}
    183
    184/* Add modifiers if 'reg' maps to x86-64 registers R8..R15 */
    185static u8 add_1mod(u8 byte, u32 reg)
    186{
    187	if (is_ereg(reg))
    188		byte |= 1;
    189	return byte;
    190}
    191
    192static u8 add_2mod(u8 byte, u32 r1, u32 r2)
    193{
    194	if (is_ereg(r1))
    195		byte |= 1;
    196	if (is_ereg(r2))
    197		byte |= 4;
    198	return byte;
    199}
    200
    201/* Encode 'dst_reg' register into x86-64 opcode 'byte' */
    202static u8 add_1reg(u8 byte, u32 dst_reg)
    203{
    204	return byte + reg2hex[dst_reg];
    205}
    206
    207/* Encode 'dst_reg' and 'src_reg' registers into x86-64 opcode 'byte' */
    208static u8 add_2reg(u8 byte, u32 dst_reg, u32 src_reg)
    209{
    210	return byte + reg2hex[dst_reg] + (reg2hex[src_reg] << 3);
    211}
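/*
 * Worked example: emit_mov_reg() emits EMIT2(0x89, add_2reg(0xC0, dst_reg,
 * src_reg)) for a 32-bit register move.  With dst_reg == BPF_REG_1 (RDI,
 * reg2hex 7) and src_reg == BPF_REG_2 (RSI, reg2hex 6) the ModRM byte is
 * 0xC0 + 7 + (6 << 3) == 0xF7, so the bytes 89 F7 are emitted, i.e.
 * 'mov edi, esi'.
 */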
    212
    213/* Some 1-byte opcodes for binary ALU operations */
    214static u8 simple_alu_opcodes[] = {
    215	[BPF_ADD] = 0x01,
    216	[BPF_SUB] = 0x29,
    217	[BPF_AND] = 0x21,
    218	[BPF_OR] = 0x09,
    219	[BPF_XOR] = 0x31,
    220	[BPF_LSH] = 0xE0,
    221	[BPF_RSH] = 0xE8,
    222	[BPF_ARSH] = 0xF8,
    223};
    224
    225static void jit_fill_hole(void *area, unsigned int size)
    226{
    227	/* Fill whole space with INT3 instructions */
    228	memset(area, 0xcc, size);
    229}
    230
    231int bpf_arch_text_invalidate(void *dst, size_t len)
    232{
    233	return IS_ERR_OR_NULL(text_poke_set(dst, 0xcc, len));
    234}
    235
    236struct jit_context {
    237	int cleanup_addr; /* Epilogue code offset */
    238
    239	/*
    240	 * Program specific offsets of labels in the code; these rely on the
    241	 * JIT doing at least 2 passes, recording the position on the first
    242	 * pass, only to generate the correct offset on the second pass.
    243	 */
    244	int tail_call_direct_label;
    245	int tail_call_indirect_label;
    246};
    247
    248/* Maximum number of bytes emitted while JITing one eBPF insn */
    249#define BPF_MAX_INSN_SIZE	128
    250#define BPF_INSN_SAFETY		64
    251
    252/* Number of bytes emit_patch() needs to generate instructions */
    253#define X86_PATCH_SIZE		5
    254/* Number of bytes that will be skipped on tailcall */
    255#define X86_TAIL_CALL_OFFSET	(11 + ENDBR_INSN_SIZE)
    256
    257static void push_callee_regs(u8 **pprog, bool *callee_regs_used)
    258{
    259	u8 *prog = *pprog;
    260
    261	if (callee_regs_used[0])
    262		EMIT1(0x53);         /* push rbx */
    263	if (callee_regs_used[1])
    264		EMIT2(0x41, 0x55);   /* push r13 */
    265	if (callee_regs_used[2])
    266		EMIT2(0x41, 0x56);   /* push r14 */
    267	if (callee_regs_used[3])
    268		EMIT2(0x41, 0x57);   /* push r15 */
    269	*pprog = prog;
    270}
    271
    272static void pop_callee_regs(u8 **pprog, bool *callee_regs_used)
    273{
    274	u8 *prog = *pprog;
    275
    276	if (callee_regs_used[3])
    277		EMIT2(0x41, 0x5F);   /* pop r15 */
    278	if (callee_regs_used[2])
    279		EMIT2(0x41, 0x5E);   /* pop r14 */
    280	if (callee_regs_used[1])
    281		EMIT2(0x41, 0x5D);   /* pop r13 */
    282	if (callee_regs_used[0])
    283		EMIT1(0x5B);         /* pop rbx */
    284	*pprog = prog;
    285}
    286
    287/*
    288 * Emit x86-64 prologue code for BPF program.
    289 * bpf_tail_call helper will skip the first X86_TAIL_CALL_OFFSET bytes
    290 * while jumping to another program
    291 */
    292static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf,
    293			  bool tail_call_reachable, bool is_subprog)
    294{
    295	u8 *prog = *pprog;
    296
    297	/* BPF trampoline can be made to work without these nops,
    298	 * but let's waste 5 bytes for now and optimize later
    299	 */
    300	EMIT_ENDBR();
    301	memcpy(prog, x86_nops[5], X86_PATCH_SIZE);
    302	prog += X86_PATCH_SIZE;
    303	if (!ebpf_from_cbpf) {
    304		if (tail_call_reachable && !is_subprog)
    305			EMIT2(0x31, 0xC0); /* xor eax, eax */
    306		else
    307			EMIT2(0x66, 0x90); /* nop2 */
    308	}
    309	EMIT1(0x55);             /* push rbp */
    310	EMIT3(0x48, 0x89, 0xE5); /* mov rbp, rsp */
    311
    312	/* X86_TAIL_CALL_OFFSET is here */
    313	EMIT_ENDBR();
    314
    315	/* sub rsp, rounded_stack_depth */
    316	if (stack_depth)
    317		EMIT3_off32(0x48, 0x81, 0xEC, round_up(stack_depth, 8));
    318	if (tail_call_reachable)
    319		EMIT1(0x50);         /* push rax */
    320	*pprog = prog;
    321}
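/*
 * Byte accounting for X86_TAIL_CALL_OFFSET (11 + ENDBR_INSN_SIZE): a tail
 * call skips the leading ENDBR (when IBT is built in), the 5-byte patchable
 * nop (X86_PATCH_SIZE), the 2-byte 'xor eax, eax'/nop2 slot emitted for
 * programs not converted from classic BPF, 'push rbp' (1 byte) and
 * 'mov rbp, rsp' (3 bytes): 5 + 2 + 1 + 3 == 11.  The jump therefore lands
 * on the second EMIT_ENDBR() above, keeping the tail-call target a valid
 * indirect-branch target under IBT.
 */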
    322
    323static int emit_patch(u8 **pprog, void *func, void *ip, u8 opcode)
    324{
    325	u8 *prog = *pprog;
    326	s64 offset;
    327
    328	offset = func - (ip + X86_PATCH_SIZE);
    329	if (!is_simm32(offset)) {
    330		pr_err("Target call %p is out of range\n", func);
    331		return -ERANGE;
    332	}
    333	EMIT1_off32(opcode, offset);
    334	*pprog = prog;
    335	return 0;
    336}
    337
    338static int emit_call(u8 **pprog, void *func, void *ip)
    339{
    340	return emit_patch(pprog, func, ip, 0xE8);
    341}
    342
    343static int emit_jump(u8 **pprog, void *func, void *ip)
    344{
    345	return emit_patch(pprog, func, ip, 0xE9);
    346}
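/*
 * Both patchable forms are X86_PATCH_SIZE (5) bytes: one opcode byte, 0xE8
 * for 'call rel32' or 0xE9 for 'jmp rel32', plus a 32-bit displacement
 * relative to the end of the instruction -- hence
 * 'offset = func - (ip + X86_PATCH_SIZE)' above.  For example, a target
 * 0x100 bytes past ip is encoded with a rel32 of 0xFB.
 */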
    347
    348static int __bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t,
    349				void *old_addr, void *new_addr)
    350{
    351	const u8 *nop_insn = x86_nops[5];
    352	u8 old_insn[X86_PATCH_SIZE];
    353	u8 new_insn[X86_PATCH_SIZE];
    354	u8 *prog;
    355	int ret;
    356
    357	memcpy(old_insn, nop_insn, X86_PATCH_SIZE);
    358	if (old_addr) {
    359		prog = old_insn;
    360		ret = t == BPF_MOD_CALL ?
    361		      emit_call(&prog, old_addr, ip) :
    362		      emit_jump(&prog, old_addr, ip);
    363		if (ret)
    364			return ret;
    365	}
    366
    367	memcpy(new_insn, nop_insn, X86_PATCH_SIZE);
    368	if (new_addr) {
    369		prog = new_insn;
    370		ret = t == BPF_MOD_CALL ?
    371		      emit_call(&prog, new_addr, ip) :
    372		      emit_jump(&prog, new_addr, ip);
    373		if (ret)
    374			return ret;
    375	}
    376
    377	ret = -EBUSY;
    378	mutex_lock(&text_mutex);
    379	if (memcmp(ip, old_insn, X86_PATCH_SIZE))
    380		goto out;
    381	ret = 1;
    382	if (memcmp(ip, new_insn, X86_PATCH_SIZE)) {
    383		text_poke_bp(ip, new_insn, X86_PATCH_SIZE, NULL);
    384		ret = 0;
    385	}
    386out:
    387	mutex_unlock(&text_mutex);
    388	return ret;
    389}
    390
    391int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t,
    392		       void *old_addr, void *new_addr)
    393{
    394	if (!is_kernel_text((long)ip) &&
    395	    !is_bpf_text_address((long)ip))
    396		/* BPF poking in modules is not supported */
    397		return -EINVAL;
    398
    399	/*
    400	 * See emit_prologue(), for IBT builds the trampoline hook is preceded
    401	 * with an ENDBR instruction.
    402	 */
    403	if (is_endbr(*(u32 *)ip))
    404		ip += ENDBR_INSN_SIZE;
    405
    406	return __bpf_arch_text_poke(ip, t, old_addr, new_addr);
    407}
    408
    409#define EMIT_LFENCE()	EMIT3(0x0F, 0xAE, 0xE8)
    410
    411static void emit_indirect_jump(u8 **pprog, int reg, u8 *ip)
    412{
    413	u8 *prog = *pprog;
    414
    415#ifdef CONFIG_RETPOLINE
    416	if (cpu_feature_enabled(X86_FEATURE_RETPOLINE_LFENCE)) {
    417		EMIT_LFENCE();
    418		EMIT2(0xFF, 0xE0 + reg);
    419	} else if (cpu_feature_enabled(X86_FEATURE_RETPOLINE)) {
    420		OPTIMIZER_HIDE_VAR(reg);
    421		emit_jump(&prog, &__x86_indirect_thunk_array[reg], ip);
    422	} else
    423#endif
    424	EMIT2(0xFF, 0xE0 + reg);
    425
    426	*pprog = prog;
    427}
    428
    429/*
    430 * Generate the following code:
    431 *
    432 * ... bpf_tail_call(void *ctx, struct bpf_array *array, u64 index) ...
    433 *   if (index >= array->map.max_entries)
    434 *     goto out;
    435 *   if (tail_call_cnt++ >= MAX_TAIL_CALL_CNT)
    436 *     goto out;
    437 *   prog = array->ptrs[index];
    438 *   if (prog == NULL)
    439 *     goto out;
    440 *   goto *(prog->bpf_func + prologue_size);
    441 * out:
    442 */
    443static void emit_bpf_tail_call_indirect(u8 **pprog, bool *callee_regs_used,
    444					u32 stack_depth, u8 *ip,
    445					struct jit_context *ctx)
    446{
    447	int tcc_off = -4 - round_up(stack_depth, 8);
    448	u8 *prog = *pprog, *start = *pprog;
    449	int offset;
    450
    451	/*
    452	 * rdi - pointer to ctx
    453	 * rsi - pointer to bpf_array
    454	 * rdx - index in bpf_array
    455	 */
    456
    457	/*
    458	 * if (index >= array->map.max_entries)
    459	 *	goto out;
    460	 */
    461	EMIT2(0x89, 0xD2);                        /* mov edx, edx */
    462	EMIT3(0x39, 0x56,                         /* cmp dword ptr [rsi + 16], edx */
    463	      offsetof(struct bpf_array, map.max_entries));
    464
    465	offset = ctx->tail_call_indirect_label - (prog + 2 - start);
    466	EMIT2(X86_JBE, offset);                   /* jbe out */
    467
    468	/*
    469	 * if (tail_call_cnt++ >= MAX_TAIL_CALL_CNT)
    470	 *	goto out;
    471	 */
    472	EMIT2_off32(0x8B, 0x85, tcc_off);         /* mov eax, dword ptr [rbp - tcc_off] */
    473	EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT);     /* cmp eax, MAX_TAIL_CALL_CNT */
    474
    475	offset = ctx->tail_call_indirect_label - (prog + 2 - start);
    476	EMIT2(X86_JAE, offset);                   /* jae out */
    477	EMIT3(0x83, 0xC0, 0x01);                  /* add eax, 1 */
    478	EMIT2_off32(0x89, 0x85, tcc_off);         /* mov dword ptr [rbp - tcc_off], eax */
    479
    480	/* prog = array->ptrs[index]; */
    481	EMIT4_off32(0x48, 0x8B, 0x8C, 0xD6,       /* mov rcx, [rsi + rdx * 8 + offsetof(...)] */
    482		    offsetof(struct bpf_array, ptrs));
    483
    484	/*
    485	 * if (prog == NULL)
    486	 *	goto out;
    487	 */
    488	EMIT3(0x48, 0x85, 0xC9);                  /* test rcx,rcx */
    489
    490	offset = ctx->tail_call_indirect_label - (prog + 2 - start);
    491	EMIT2(X86_JE, offset);                    /* je out */
    492
    493	pop_callee_regs(&prog, callee_regs_used);
    494
    495	EMIT1(0x58);                              /* pop rax */
    496	if (stack_depth)
    497		EMIT3_off32(0x48, 0x81, 0xC4,     /* add rsp, sd */
    498			    round_up(stack_depth, 8));
    499
    500	/* goto *(prog->bpf_func + X86_TAIL_CALL_OFFSET); */
    501	EMIT4(0x48, 0x8B, 0x49,                   /* mov rcx, qword ptr [rcx + 32] */
    502	      offsetof(struct bpf_prog, bpf_func));
    503	EMIT4(0x48, 0x83, 0xC1,                   /* add rcx, X86_TAIL_CALL_OFFSET */
    504	      X86_TAIL_CALL_OFFSET);
    505	/*
    506	 * Now we're ready to jump into next BPF program
    507	 * rdi == ctx (1st arg)
    508	 * rcx == prog->bpf_func + X86_TAIL_CALL_OFFSET
    509	 */
    510	emit_indirect_jump(&prog, 1 /* rcx */, ip + (prog - start));
    511
    512	/* out: */
    513	ctx->tail_call_indirect_label = prog - start;
    514	*pprog = prog;
    515}
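/*
 * Note that the three forward branches above read
 * ctx->tail_call_indirect_label before it is written at the end of this
 * function: as described in struct jit_context, the value recorded on one
 * JIT pass is consumed on the next, so the emitted offsets are only correct
 * once the image layout has converged.
 */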
    516
    517static void emit_bpf_tail_call_direct(struct bpf_jit_poke_descriptor *poke,
    518				      u8 **pprog, u8 *ip,
    519				      bool *callee_regs_used, u32 stack_depth,
    520				      struct jit_context *ctx)
    521{
    522	int tcc_off = -4 - round_up(stack_depth, 8);
    523	u8 *prog = *pprog, *start = *pprog;
    524	int offset;
    525
    526	/*
    527	 * if (tail_call_cnt++ >= MAX_TAIL_CALL_CNT)
    528	 *	goto out;
    529	 */
    530	EMIT2_off32(0x8B, 0x85, tcc_off);             /* mov eax, dword ptr [rbp - tcc_off] */
    531	EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT);         /* cmp eax, MAX_TAIL_CALL_CNT */
    532
    533	offset = ctx->tail_call_direct_label - (prog + 2 - start);
    534	EMIT2(X86_JAE, offset);                       /* jae out */
    535	EMIT3(0x83, 0xC0, 0x01);                      /* add eax, 1 */
    536	EMIT2_off32(0x89, 0x85, tcc_off);             /* mov dword ptr [rbp - tcc_off], eax */
    537
    538	poke->tailcall_bypass = ip + (prog - start);
    539	poke->adj_off = X86_TAIL_CALL_OFFSET;
    540	poke->tailcall_target = ip + ctx->tail_call_direct_label - X86_PATCH_SIZE;
    541	poke->bypass_addr = (u8 *)poke->tailcall_target + X86_PATCH_SIZE;
    542
    543	emit_jump(&prog, (u8 *)poke->tailcall_target + X86_PATCH_SIZE,
    544		  poke->tailcall_bypass);
    545
    546	pop_callee_regs(&prog, callee_regs_used);
    547	EMIT1(0x58);                                  /* pop rax */
    548	if (stack_depth)
    549		EMIT3_off32(0x48, 0x81, 0xC4, round_up(stack_depth, 8));
    550
    551	memcpy(prog, x86_nops[5], X86_PATCH_SIZE);
    552	prog += X86_PATCH_SIZE;
    553
    554	/* out: */
    555	ctx->tail_call_direct_label = prog - start;
    556
    557	*pprog = prog;
    558}
    559
    560static void bpf_tail_call_direct_fixup(struct bpf_prog *prog)
    561{
    562	struct bpf_jit_poke_descriptor *poke;
    563	struct bpf_array *array;
    564	struct bpf_prog *target;
    565	int i, ret;
    566
    567	for (i = 0; i < prog->aux->size_poke_tab; i++) {
    568		poke = &prog->aux->poke_tab[i];
    569		if (poke->aux && poke->aux != prog->aux)
    570			continue;
    571
    572		WARN_ON_ONCE(READ_ONCE(poke->tailcall_target_stable));
    573
    574		if (poke->reason != BPF_POKE_REASON_TAIL_CALL)
    575			continue;
    576
    577		array = container_of(poke->tail_call.map, struct bpf_array, map);
    578		mutex_lock(&array->aux->poke_mutex);
    579		target = array->ptrs[poke->tail_call.key];
    580		if (target) {
    581			ret = __bpf_arch_text_poke(poke->tailcall_target,
    582						   BPF_MOD_JUMP, NULL,
    583						   (u8 *)target->bpf_func +
    584						   poke->adj_off);
    585			BUG_ON(ret < 0);
    586			ret = __bpf_arch_text_poke(poke->tailcall_bypass,
    587						   BPF_MOD_JUMP,
    588						   (u8 *)poke->tailcall_target +
    589						   X86_PATCH_SIZE, NULL);
    590			BUG_ON(ret < 0);
    591		}
    592		WRITE_ONCE(poke->tailcall_target_stable, true);
    593		mutex_unlock(&array->aux->poke_mutex);
    594	}
    595}
    596
    597static void emit_mov_imm32(u8 **pprog, bool sign_propagate,
    598			   u32 dst_reg, const u32 imm32)
    599{
    600	u8 *prog = *pprog;
    601	u8 b1, b2, b3;
    602
    603	/*
    604	 * Optimization: if imm32 is positive, use 'mov %eax, imm32'
    605	 * (which zero-extends imm32) to save 2 bytes.
    606	 */
    607	if (sign_propagate && (s32)imm32 < 0) {
    608		/* 'mov %rax, imm32' sign extends imm32 */
    609		b1 = add_1mod(0x48, dst_reg);
    610		b2 = 0xC7;
    611		b3 = 0xC0;
    612		EMIT3_off32(b1, b2, add_1reg(b3, dst_reg), imm32);
    613		goto done;
    614	}
    615
    616	/*
    617	 * Optimization: if imm32 is zero, use 'xor %eax, %eax'
    618	 * to save 3 bytes.
    619	 */
    620	if (imm32 == 0) {
    621		if (is_ereg(dst_reg))
    622			EMIT1(add_2mod(0x40, dst_reg, dst_reg));
    623		b2 = 0x31; /* xor */
    624		b3 = 0xC0;
    625		EMIT2(b2, add_2reg(b3, dst_reg, dst_reg));
    626		goto done;
    627	}
    628
    629	/* mov %eax, imm32 */
    630	if (is_ereg(dst_reg))
    631		EMIT1(add_1mod(0x40, dst_reg));
    632	EMIT1_off32(add_1reg(0xB8, dst_reg), imm32);
    633done:
    634	*pprog = prog;
    635}
    636
    637static void emit_mov_imm64(u8 **pprog, u32 dst_reg,
    638			   const u32 imm32_hi, const u32 imm32_lo)
    639{
    640	u8 *prog = *pprog;
    641
    642	if (is_uimm32(((u64)imm32_hi << 32) | (u32)imm32_lo)) {
    643		/*
    644		 * For emitting plain u32, where sign bit must not be
    645		 * propagated LLVM tends to load imm64 over mov32
    646		 * directly, so save couple of bytes by just doing
    647		 * 'mov %eax, imm32' instead.
    648		 */
    649		emit_mov_imm32(&prog, false, dst_reg, imm32_lo);
    650	} else {
    651		/* movabsq %rax, imm64 */
    652		EMIT2(add_1mod(0x48, dst_reg), add_1reg(0xB8, dst_reg));
    653		EMIT(imm32_lo, 4);
    654		EMIT(imm32_hi, 4);
    655	}
    656
    657	*pprog = prog;
    658}
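/*
 * Size example: loading the constant 0x12345678 into BPF_REG_0 (RAX) takes
 * the 5-byte 'mov eax, 0x12345678' path (B8 78 56 34 12; the 32-bit write
 * clears the upper half of RAX), while a constant with any of the upper
 * 32 bits set needs the 10-byte 'movabs rax, imm64' (48 B8 followed by
 * eight immediate bytes).
 */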
    659
    660static void emit_mov_reg(u8 **pprog, bool is64, u32 dst_reg, u32 src_reg)
    661{
    662	u8 *prog = *pprog;
    663
    664	if (is64) {
    665		/* mov dst, src */
    666		EMIT_mov(dst_reg, src_reg);
    667	} else {
    668		/* mov32 dst, src */
    669		if (is_ereg(dst_reg) || is_ereg(src_reg))
    670			EMIT1(add_2mod(0x40, dst_reg, src_reg));
    671		EMIT2(0x89, add_2reg(0xC0, dst_reg, src_reg));
    672	}
    673
    674	*pprog = prog;
    675}
    676
    677/* Emit the suffix (ModR/M etc) for addressing *(ptr_reg + off) and val_reg */
    678static void emit_insn_suffix(u8 **pprog, u32 ptr_reg, u32 val_reg, int off)
    679{
    680	u8 *prog = *pprog;
    681
    682	if (is_imm8(off)) {
    683		/* 1-byte signed displacement.
    684		 *
    685		 * If off == 0 we could skip this and save one extra byte, but
    686		 * special case of x86 R13 which always needs an offset is not
    687		 * worth the hassle
    688		 */
    689		EMIT2(add_2reg(0x40, ptr_reg, val_reg), off);
    690	} else {
    691		/* 4-byte signed displacement */
    692		EMIT1_off32(add_2reg(0x80, ptr_reg, val_reg), off);
    693	}
    694	*pprog = prog;
    695}
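/*
 * Example: emit_stx() below with size BPF_DW, dst_reg == BPF_REG_FP,
 * src_reg == BPF_REG_1 and off == -16 emits REX.W + 0x89 followed by this
 * suffix; -16 fits in imm8, so the ModRM byte is 0x40 + 5 (RBP) +
 * (7 << 3) (RDI) == 0x7D, plus the displacement byte 0xF0.  The result,
 * 48 89 7D F0, is 'mov QWORD PTR [rbp-0x10],rdi' as shown in the
 * save_regs() comment near the end of this file.
 */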
    696
    697/*
    698 * Emit a REX byte if it will be necessary to address these registers
    699 */
    700static void maybe_emit_mod(u8 **pprog, u32 dst_reg, u32 src_reg, bool is64)
    701{
    702	u8 *prog = *pprog;
    703
    704	if (is64)
    705		EMIT1(add_2mod(0x48, dst_reg, src_reg));
    706	else if (is_ereg(dst_reg) || is_ereg(src_reg))
    707		EMIT1(add_2mod(0x40, dst_reg, src_reg));
    708	*pprog = prog;
    709}
    710
    711/*
    712 * Similar version of maybe_emit_mod() for a single register
    713 */
    714static void maybe_emit_1mod(u8 **pprog, u32 reg, bool is64)
    715{
    716	u8 *prog = *pprog;
    717
    718	if (is64)
    719		EMIT1(add_1mod(0x48, reg));
    720	else if (is_ereg(reg))
    721		EMIT1(add_1mod(0x40, reg));
    722	*pprog = prog;
    723}
    724
    725/* LDX: dst_reg = *(u8*)(src_reg + off) */
    726static void emit_ldx(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, int off)
    727{
    728	u8 *prog = *pprog;
    729
    730	switch (size) {
    731	case BPF_B:
    732		/* Emit 'movzx rax, byte ptr [rax + off]' */
    733		EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xB6);
    734		break;
    735	case BPF_H:
    736		/* Emit 'movzx rax, word ptr [rax + off]' */
    737		EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xB7);
    738		break;
    739	case BPF_W:
    740		/* Emit 'mov eax, dword ptr [rax+0x14]' */
    741		if (is_ereg(dst_reg) || is_ereg(src_reg))
    742			EMIT2(add_2mod(0x40, src_reg, dst_reg), 0x8B);
    743		else
    744			EMIT1(0x8B);
    745		break;
    746	case BPF_DW:
    747		/* Emit 'mov rax, qword ptr [rax+0x14]' */
    748		EMIT2(add_2mod(0x48, src_reg, dst_reg), 0x8B);
    749		break;
    750	}
    751	emit_insn_suffix(&prog, src_reg, dst_reg, off);
    752	*pprog = prog;
    753}
    754
    755/* STX: *(u8*)(dst_reg + off) = src_reg */
    756static void emit_stx(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, int off)
    757{
    758	u8 *prog = *pprog;
    759
    760	switch (size) {
    761	case BPF_B:
    762		/* Emit 'mov byte ptr [rax + off], al' */
    763		if (is_ereg(dst_reg) || is_ereg_8l(src_reg))
    764			/* Add extra byte for eregs or SIL,DIL,BPL in src_reg */
    765			EMIT2(add_2mod(0x40, dst_reg, src_reg), 0x88);
    766		else
    767			EMIT1(0x88);
    768		break;
    769	case BPF_H:
    770		if (is_ereg(dst_reg) || is_ereg(src_reg))
    771			EMIT3(0x66, add_2mod(0x40, dst_reg, src_reg), 0x89);
    772		else
    773			EMIT2(0x66, 0x89);
    774		break;
    775	case BPF_W:
    776		if (is_ereg(dst_reg) || is_ereg(src_reg))
    777			EMIT2(add_2mod(0x40, dst_reg, src_reg), 0x89);
    778		else
    779			EMIT1(0x89);
    780		break;
    781	case BPF_DW:
    782		EMIT2(add_2mod(0x48, dst_reg, src_reg), 0x89);
    783		break;
    784	}
    785	emit_insn_suffix(&prog, dst_reg, src_reg, off);
    786	*pprog = prog;
    787}
    788
    789static int emit_atomic(u8 **pprog, u8 atomic_op,
    790		       u32 dst_reg, u32 src_reg, s16 off, u8 bpf_size)
    791{
    792	u8 *prog = *pprog;
    793
    794	EMIT1(0xF0); /* lock prefix */
    795
    796	maybe_emit_mod(&prog, dst_reg, src_reg, bpf_size == BPF_DW);
    797
    798	/* emit opcode */
    799	switch (atomic_op) {
    800	case BPF_ADD:
    801	case BPF_AND:
    802	case BPF_OR:
    803	case BPF_XOR:
    804		/* lock *(u32/u64*)(dst_reg + off) <op>= src_reg */
    805		EMIT1(simple_alu_opcodes[atomic_op]);
    806		break;
    807	case BPF_ADD | BPF_FETCH:
    808		/* src_reg = atomic_fetch_add(dst_reg + off, src_reg); */
    809		EMIT2(0x0F, 0xC1);
    810		break;
    811	case BPF_XCHG:
    812		/* src_reg = atomic_xchg(dst_reg + off, src_reg); */
    813		EMIT1(0x87);
    814		break;
    815	case BPF_CMPXCHG:
    816		/* r0 = atomic_cmpxchg(dst_reg + off, r0, src_reg); */
    817		EMIT2(0x0F, 0xB1);
    818		break;
    819	default:
    820		pr_err("bpf_jit: unknown atomic opcode %02x\n", atomic_op);
    821		return -EFAULT;
    822	}
    823
    824	emit_insn_suffix(&prog, dst_reg, src_reg, off);
    825
    826	*pprog = prog;
    827	return 0;
    828}
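/*
 * Encoding example: a BPF_ADD atomic of size BPF_DW with dst_reg ==
 * BPF_REG_1, src_reg == BPF_REG_2 and off == 0 emits F0 48 01 77 00, i.e.
 * 'lock add QWORD PTR [rdi+0x0],rsi'.
 */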
    829
    830bool ex_handler_bpf(const struct exception_table_entry *x, struct pt_regs *regs)
    831{
    832	u32 reg = x->fixup >> 8;
    833
    834	/* jump over faulting load and clear dest register */
    835	*(unsigned long *)((void *)regs + reg) = 0;
    836	regs->ip += x->fixup & 0xff;
    837	return true;
    838}
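/*
 * The fixup word decoded here is built by the BPF_PROBE_MEM handling in
 * do_jit(): the low 8 bits hold the length of the faulting load instruction
 * (added to regs->ip to skip it) and the upper bits hold the pt_regs offset
 * of the destination register, taken from reg2pt_regs[], which is the slot
 * zeroed above.
 */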
    839
    840static void detect_reg_usage(struct bpf_insn *insn, int insn_cnt,
    841			     bool *regs_used, bool *tail_call_seen)
    842{
    843	int i;
    844
    845	for (i = 1; i <= insn_cnt; i++, insn++) {
    846		if (insn->code == (BPF_JMP | BPF_TAIL_CALL))
    847			*tail_call_seen = true;
    848		if (insn->dst_reg == BPF_REG_6 || insn->src_reg == BPF_REG_6)
    849			regs_used[0] = true;
    850		if (insn->dst_reg == BPF_REG_7 || insn->src_reg == BPF_REG_7)
    851			regs_used[1] = true;
    852		if (insn->dst_reg == BPF_REG_8 || insn->src_reg == BPF_REG_8)
    853			regs_used[2] = true;
    854		if (insn->dst_reg == BPF_REG_9 || insn->src_reg == BPF_REG_9)
    855			regs_used[3] = true;
    856	}
    857}
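/*
 * callee_regs_used[0..3] correspond to BPF R6..R9 (RBX, R13, R14, R15); the
 * usage gathered here determines which callee-saved registers
 * push_callee_regs() and pop_callee_regs() actually save and restore, so
 * registers a program never touches cost nothing in the prologue/epilogue.
 */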
    858
    859static void emit_nops(u8 **pprog, int len)
    860{
    861	u8 *prog = *pprog;
    862	int i, noplen;
    863
    864	while (len > 0) {
    865		noplen = len;
    866
    867		if (noplen > ASM_NOP_MAX)
    868			noplen = ASM_NOP_MAX;
    869
    870		for (i = 0; i < noplen; i++)
    871			EMIT1(x86_nops[noplen][i]);
    872		len -= noplen;
    873	}
    874
    875	*pprog = prog;
    876}
    877
    878#define INSN_SZ_DIFF (((addrs[i] - addrs[i - 1]) - (prog - temp)))
    879
    880static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, u8 *rw_image,
    881		  int oldproglen, struct jit_context *ctx, bool jmp_padding)
    882{
    883	bool tail_call_reachable = bpf_prog->aux->tail_call_reachable;
    884	struct bpf_insn *insn = bpf_prog->insnsi;
    885	bool callee_regs_used[4] = {};
    886	int insn_cnt = bpf_prog->len;
    887	bool tail_call_seen = false;
    888	bool seen_exit = false;
    889	u8 temp[BPF_MAX_INSN_SIZE + BPF_INSN_SAFETY];
    890	int i, excnt = 0;
    891	int ilen, proglen = 0;
    892	u8 *prog = temp;
    893	int err;
    894
    895	detect_reg_usage(insn, insn_cnt, callee_regs_used,
    896			 &tail_call_seen);
    897
    898	/* tail call's presence in current prog implies it is reachable */
    899	tail_call_reachable |= tail_call_seen;
    900
    901	emit_prologue(&prog, bpf_prog->aux->stack_depth,
    902		      bpf_prog_was_classic(bpf_prog), tail_call_reachable,
    903		      bpf_prog->aux->func_idx != 0);
    904	push_callee_regs(&prog, callee_regs_used);
    905
    906	ilen = prog - temp;
    907	if (rw_image)
    908		memcpy(rw_image + proglen, temp, ilen);
    909	proglen += ilen;
    910	addrs[0] = proglen;
    911	prog = temp;
    912
    913	for (i = 1; i <= insn_cnt; i++, insn++) {
    914		const s32 imm32 = insn->imm;
    915		u32 dst_reg = insn->dst_reg;
    916		u32 src_reg = insn->src_reg;
    917		u8 b2 = 0, b3 = 0;
    918		u8 *start_of_ldx;
    919		s64 jmp_offset;
    920		u8 jmp_cond;
    921		u8 *func;
    922		int nops;
    923
    924		switch (insn->code) {
    925			/* ALU */
    926		case BPF_ALU | BPF_ADD | BPF_X:
    927		case BPF_ALU | BPF_SUB | BPF_X:
    928		case BPF_ALU | BPF_AND | BPF_X:
    929		case BPF_ALU | BPF_OR | BPF_X:
    930		case BPF_ALU | BPF_XOR | BPF_X:
    931		case BPF_ALU64 | BPF_ADD | BPF_X:
    932		case BPF_ALU64 | BPF_SUB | BPF_X:
    933		case BPF_ALU64 | BPF_AND | BPF_X:
    934		case BPF_ALU64 | BPF_OR | BPF_X:
    935		case BPF_ALU64 | BPF_XOR | BPF_X:
    936			maybe_emit_mod(&prog, dst_reg, src_reg,
    937				       BPF_CLASS(insn->code) == BPF_ALU64);
    938			b2 = simple_alu_opcodes[BPF_OP(insn->code)];
    939			EMIT2(b2, add_2reg(0xC0, dst_reg, src_reg));
    940			break;
    941
    942		case BPF_ALU64 | BPF_MOV | BPF_X:
    943		case BPF_ALU | BPF_MOV | BPF_X:
    944			emit_mov_reg(&prog,
    945				     BPF_CLASS(insn->code) == BPF_ALU64,
    946				     dst_reg, src_reg);
    947			break;
    948
    949			/* neg dst */
    950		case BPF_ALU | BPF_NEG:
    951		case BPF_ALU64 | BPF_NEG:
    952			maybe_emit_1mod(&prog, dst_reg,
    953					BPF_CLASS(insn->code) == BPF_ALU64);
    954			EMIT2(0xF7, add_1reg(0xD8, dst_reg));
    955			break;
    956
    957		case BPF_ALU | BPF_ADD | BPF_K:
    958		case BPF_ALU | BPF_SUB | BPF_K:
    959		case BPF_ALU | BPF_AND | BPF_K:
    960		case BPF_ALU | BPF_OR | BPF_K:
    961		case BPF_ALU | BPF_XOR | BPF_K:
    962		case BPF_ALU64 | BPF_ADD | BPF_K:
    963		case BPF_ALU64 | BPF_SUB | BPF_K:
    964		case BPF_ALU64 | BPF_AND | BPF_K:
    965		case BPF_ALU64 | BPF_OR | BPF_K:
    966		case BPF_ALU64 | BPF_XOR | BPF_K:
    967			maybe_emit_1mod(&prog, dst_reg,
    968					BPF_CLASS(insn->code) == BPF_ALU64);
    969
    970			/*
    971			 * b3 holds 'normal' opcode, b2 short form only valid
    972			 * in case dst is eax/rax.
    973			 */
    974			switch (BPF_OP(insn->code)) {
    975			case BPF_ADD:
    976				b3 = 0xC0;
    977				b2 = 0x05;
    978				break;
    979			case BPF_SUB:
    980				b3 = 0xE8;
    981				b2 = 0x2D;
    982				break;
    983			case BPF_AND:
    984				b3 = 0xE0;
    985				b2 = 0x25;
    986				break;
    987			case BPF_OR:
    988				b3 = 0xC8;
    989				b2 = 0x0D;
    990				break;
    991			case BPF_XOR:
    992				b3 = 0xF0;
    993				b2 = 0x35;
    994				break;
    995			}
    996
    997			if (is_imm8(imm32))
    998				EMIT3(0x83, add_1reg(b3, dst_reg), imm32);
    999			else if (is_axreg(dst_reg))
   1000				EMIT1_off32(b2, imm32);
   1001			else
   1002				EMIT2_off32(0x81, add_1reg(b3, dst_reg), imm32);
   1003			break;
   1004
   1005		case BPF_ALU64 | BPF_MOV | BPF_K:
   1006		case BPF_ALU | BPF_MOV | BPF_K:
   1007			emit_mov_imm32(&prog, BPF_CLASS(insn->code) == BPF_ALU64,
   1008				       dst_reg, imm32);
   1009			break;
   1010
   1011		case BPF_LD | BPF_IMM | BPF_DW:
   1012			emit_mov_imm64(&prog, dst_reg, insn[1].imm, insn[0].imm);
   1013			insn++;
   1014			i++;
   1015			break;
   1016
   1017			/* dst %= src, dst /= src, dst %= imm32, dst /= imm32 */
   1018		case BPF_ALU | BPF_MOD | BPF_X:
   1019		case BPF_ALU | BPF_DIV | BPF_X:
   1020		case BPF_ALU | BPF_MOD | BPF_K:
   1021		case BPF_ALU | BPF_DIV | BPF_K:
   1022		case BPF_ALU64 | BPF_MOD | BPF_X:
   1023		case BPF_ALU64 | BPF_DIV | BPF_X:
   1024		case BPF_ALU64 | BPF_MOD | BPF_K:
   1025		case BPF_ALU64 | BPF_DIV | BPF_K: {
   1026			bool is64 = BPF_CLASS(insn->code) == BPF_ALU64;
   1027
   1028			if (dst_reg != BPF_REG_0)
   1029				EMIT1(0x50); /* push rax */
   1030			if (dst_reg != BPF_REG_3)
   1031				EMIT1(0x52); /* push rdx */
   1032
   1033			if (BPF_SRC(insn->code) == BPF_X) {
   1034				if (src_reg == BPF_REG_0 ||
   1035				    src_reg == BPF_REG_3) {
   1036					/* mov r11, src_reg */
   1037					EMIT_mov(AUX_REG, src_reg);
   1038					src_reg = AUX_REG;
   1039				}
   1040			} else {
   1041				/* mov r11, imm32 */
   1042				EMIT3_off32(0x49, 0xC7, 0xC3, imm32);
   1043				src_reg = AUX_REG;
   1044			}
   1045
   1046			if (dst_reg != BPF_REG_0)
   1047				/* mov rax, dst_reg */
   1048				emit_mov_reg(&prog, is64, BPF_REG_0, dst_reg);
   1049
   1050			/*
   1051			 * xor edx, edx
   1052			 * equivalent to 'xor rdx, rdx', but one byte less
   1053			 */
   1054			EMIT2(0x31, 0xd2);
   1055
   1056			/* div src_reg */
   1057			maybe_emit_1mod(&prog, src_reg, is64);
   1058			EMIT2(0xF7, add_1reg(0xF0, src_reg));
   1059
   1060			if (BPF_OP(insn->code) == BPF_MOD &&
   1061			    dst_reg != BPF_REG_3)
   1062				/* mov dst_reg, rdx */
   1063				emit_mov_reg(&prog, is64, dst_reg, BPF_REG_3);
   1064			else if (BPF_OP(insn->code) == BPF_DIV &&
   1065				 dst_reg != BPF_REG_0)
   1066				/* mov dst_reg, rax */
   1067				emit_mov_reg(&prog, is64, dst_reg, BPF_REG_0);
   1068
   1069			if (dst_reg != BPF_REG_3)
   1070				EMIT1(0x5A); /* pop rdx */
   1071			if (dst_reg != BPF_REG_0)
   1072				EMIT1(0x58); /* pop rax */
   1073			break;
   1074		}
   1075
   1076		case BPF_ALU | BPF_MUL | BPF_K:
   1077		case BPF_ALU64 | BPF_MUL | BPF_K:
   1078			maybe_emit_mod(&prog, dst_reg, dst_reg,
   1079				       BPF_CLASS(insn->code) == BPF_ALU64);
   1080
   1081			if (is_imm8(imm32))
   1082				/* imul dst_reg, dst_reg, imm8 */
   1083				EMIT3(0x6B, add_2reg(0xC0, dst_reg, dst_reg),
   1084				      imm32);
   1085			else
   1086				/* imul dst_reg, dst_reg, imm32 */
   1087				EMIT2_off32(0x69,
   1088					    add_2reg(0xC0, dst_reg, dst_reg),
   1089					    imm32);
   1090			break;
   1091
   1092		case BPF_ALU | BPF_MUL | BPF_X:
   1093		case BPF_ALU64 | BPF_MUL | BPF_X:
   1094			maybe_emit_mod(&prog, src_reg, dst_reg,
   1095				       BPF_CLASS(insn->code) == BPF_ALU64);
   1096
   1097			/* imul dst_reg, src_reg */
   1098			EMIT3(0x0F, 0xAF, add_2reg(0xC0, src_reg, dst_reg));
   1099			break;
   1100
   1101			/* Shifts */
   1102		case BPF_ALU | BPF_LSH | BPF_K:
   1103		case BPF_ALU | BPF_RSH | BPF_K:
   1104		case BPF_ALU | BPF_ARSH | BPF_K:
   1105		case BPF_ALU64 | BPF_LSH | BPF_K:
   1106		case BPF_ALU64 | BPF_RSH | BPF_K:
   1107		case BPF_ALU64 | BPF_ARSH | BPF_K:
   1108			maybe_emit_1mod(&prog, dst_reg,
   1109					BPF_CLASS(insn->code) == BPF_ALU64);
   1110
   1111			b3 = simple_alu_opcodes[BPF_OP(insn->code)];
   1112			if (imm32 == 1)
   1113				EMIT2(0xD1, add_1reg(b3, dst_reg));
   1114			else
   1115				EMIT3(0xC1, add_1reg(b3, dst_reg), imm32);
   1116			break;
   1117
   1118		case BPF_ALU | BPF_LSH | BPF_X:
   1119		case BPF_ALU | BPF_RSH | BPF_X:
   1120		case BPF_ALU | BPF_ARSH | BPF_X:
   1121		case BPF_ALU64 | BPF_LSH | BPF_X:
   1122		case BPF_ALU64 | BPF_RSH | BPF_X:
   1123		case BPF_ALU64 | BPF_ARSH | BPF_X:
   1124
   1125			/* Check for bad case when dst_reg == rcx */
   1126			if (dst_reg == BPF_REG_4) {
   1127				/* mov r11, dst_reg */
   1128				EMIT_mov(AUX_REG, dst_reg);
   1129				dst_reg = AUX_REG;
   1130			}
   1131
   1132			if (src_reg != BPF_REG_4) { /* common case */
   1133				EMIT1(0x51); /* push rcx */
   1134
   1135				/* mov rcx, src_reg */
   1136				EMIT_mov(BPF_REG_4, src_reg);
   1137			}
   1138
   1139			/* shl %rax, %cl | shr %rax, %cl | sar %rax, %cl */
   1140			maybe_emit_1mod(&prog, dst_reg,
   1141					BPF_CLASS(insn->code) == BPF_ALU64);
   1142
   1143			b3 = simple_alu_opcodes[BPF_OP(insn->code)];
   1144			EMIT2(0xD3, add_1reg(b3, dst_reg));
   1145
   1146			if (src_reg != BPF_REG_4)
   1147				EMIT1(0x59); /* pop rcx */
   1148
   1149			if (insn->dst_reg == BPF_REG_4)
   1150				/* mov dst_reg, r11 */
   1151				EMIT_mov(insn->dst_reg, AUX_REG);
   1152			break;
   1153
   1154		case BPF_ALU | BPF_END | BPF_FROM_BE:
   1155			switch (imm32) {
   1156			case 16:
   1157				/* Emit 'ror %ax, 8' to swap lower 2 bytes */
   1158				EMIT1(0x66);
   1159				if (is_ereg(dst_reg))
   1160					EMIT1(0x41);
   1161				EMIT3(0xC1, add_1reg(0xC8, dst_reg), 8);
   1162
   1163				/* Emit 'movzwl eax, ax' */
   1164				if (is_ereg(dst_reg))
   1165					EMIT3(0x45, 0x0F, 0xB7);
   1166				else
   1167					EMIT2(0x0F, 0xB7);
   1168				EMIT1(add_2reg(0xC0, dst_reg, dst_reg));
   1169				break;
   1170			case 32:
   1171				/* Emit 'bswap eax' to swap lower 4 bytes */
   1172				if (is_ereg(dst_reg))
   1173					EMIT2(0x41, 0x0F);
   1174				else
   1175					EMIT1(0x0F);
   1176				EMIT1(add_1reg(0xC8, dst_reg));
   1177				break;
   1178			case 64:
   1179				/* Emit 'bswap rax' to swap 8 bytes */
   1180				EMIT3(add_1mod(0x48, dst_reg), 0x0F,
   1181				      add_1reg(0xC8, dst_reg));
   1182				break;
   1183			}
   1184			break;
   1185
   1186		case BPF_ALU | BPF_END | BPF_FROM_LE:
   1187			switch (imm32) {
   1188			case 16:
   1189				/*
   1190				 * Emit 'movzwl eax, ax' to zero extend 16-bit
   1191				 * into 64 bit
   1192				 */
   1193				if (is_ereg(dst_reg))
   1194					EMIT3(0x45, 0x0F, 0xB7);
   1195				else
   1196					EMIT2(0x0F, 0xB7);
   1197				EMIT1(add_2reg(0xC0, dst_reg, dst_reg));
   1198				break;
   1199			case 32:
   1200				/* Emit 'mov eax, eax' to clear upper 32-bits */
   1201				if (is_ereg(dst_reg))
   1202					EMIT1(0x45);
   1203				EMIT2(0x89, add_2reg(0xC0, dst_reg, dst_reg));
   1204				break;
   1205			case 64:
   1206				/* nop */
   1207				break;
   1208			}
   1209			break;
   1210
   1211			/* speculation barrier */
   1212		case BPF_ST | BPF_NOSPEC:
   1213			if (boot_cpu_has(X86_FEATURE_XMM2))
   1214				EMIT_LFENCE();
   1215			break;
   1216
   1217			/* ST: *(u8*)(dst_reg + off) = imm */
   1218		case BPF_ST | BPF_MEM | BPF_B:
   1219			if (is_ereg(dst_reg))
   1220				EMIT2(0x41, 0xC6);
   1221			else
   1222				EMIT1(0xC6);
   1223			goto st;
   1224		case BPF_ST | BPF_MEM | BPF_H:
   1225			if (is_ereg(dst_reg))
   1226				EMIT3(0x66, 0x41, 0xC7);
   1227			else
   1228				EMIT2(0x66, 0xC7);
   1229			goto st;
   1230		case BPF_ST | BPF_MEM | BPF_W:
   1231			if (is_ereg(dst_reg))
   1232				EMIT2(0x41, 0xC7);
   1233			else
   1234				EMIT1(0xC7);
   1235			goto st;
   1236		case BPF_ST | BPF_MEM | BPF_DW:
   1237			EMIT2(add_1mod(0x48, dst_reg), 0xC7);
   1238
   1239st:			if (is_imm8(insn->off))
   1240				EMIT2(add_1reg(0x40, dst_reg), insn->off);
   1241			else
   1242				EMIT1_off32(add_1reg(0x80, dst_reg), insn->off);
   1243
   1244			EMIT(imm32, bpf_size_to_x86_bytes(BPF_SIZE(insn->code)));
   1245			break;
   1246
   1247			/* STX: *(u8*)(dst_reg + off) = src_reg */
   1248		case BPF_STX | BPF_MEM | BPF_B:
   1249		case BPF_STX | BPF_MEM | BPF_H:
   1250		case BPF_STX | BPF_MEM | BPF_W:
   1251		case BPF_STX | BPF_MEM | BPF_DW:
   1252			emit_stx(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn->off);
   1253			break;
   1254
   1255			/* LDX: dst_reg = *(u8*)(src_reg + off) */
   1256		case BPF_LDX | BPF_MEM | BPF_B:
   1257		case BPF_LDX | BPF_PROBE_MEM | BPF_B:
   1258		case BPF_LDX | BPF_MEM | BPF_H:
   1259		case BPF_LDX | BPF_PROBE_MEM | BPF_H:
   1260		case BPF_LDX | BPF_MEM | BPF_W:
   1261		case BPF_LDX | BPF_PROBE_MEM | BPF_W:
   1262		case BPF_LDX | BPF_MEM | BPF_DW:
   1263		case BPF_LDX | BPF_PROBE_MEM | BPF_DW:
   1264			if (BPF_MODE(insn->code) == BPF_PROBE_MEM) {
   1265				/* Though the verifier prevents negative insn->off in BPF_PROBE_MEM
   1266				 * add abs(insn->off) to the limit to make sure that negative
   1267				 * offset won't be an issue.
   1268				 * insn->off is s16, so it won't affect valid pointers.
   1269				 */
   1270				u64 limit = TASK_SIZE_MAX + PAGE_SIZE + abs(insn->off);
   1271				u8 *end_of_jmp1, *end_of_jmp2;
   1272
   1273				/* Conservatively check that src_reg + insn->off is a kernel address:
   1274				 * 1. src_reg + insn->off >= limit
   1275				 * 2. src_reg + insn->off doesn't become small positive.
   1276				 * Cannot do src_reg + insn->off >= limit in one branch,
   1277				 * since it needs two spare registers, but JIT has only one.
   1278				 */
   1279
   1280				/* movabsq r11, limit */
   1281				EMIT2(add_1mod(0x48, AUX_REG), add_1reg(0xB8, AUX_REG));
   1282				EMIT((u32)limit, 4);
   1283				EMIT(limit >> 32, 4);
   1284				/* cmp src_reg, r11 */
   1285				maybe_emit_mod(&prog, src_reg, AUX_REG, true);
   1286				EMIT2(0x39, add_2reg(0xC0, src_reg, AUX_REG));
   1287				/* if unsigned '<' goto end_of_jmp2 */
   1288				EMIT2(X86_JB, 0);
   1289				end_of_jmp1 = prog;
   1290
   1291				/* mov r11, src_reg */
   1292				emit_mov_reg(&prog, true, AUX_REG, src_reg);
   1293				/* add r11, insn->off */
   1294				maybe_emit_1mod(&prog, AUX_REG, true);
   1295				EMIT2_off32(0x81, add_1reg(0xC0, AUX_REG), insn->off);
   1296				/* jmp if not carry to start_of_ldx
   1297				 * Otherwise ERR_PTR(-EINVAL) + 128 will be the user addr
   1298				 * that has to be rejected.
   1299				 */
   1300				EMIT2(0x73 /* JNC */, 0);
   1301				end_of_jmp2 = prog;
   1302
   1303				/* xor dst_reg, dst_reg */
   1304				emit_mov_imm32(&prog, false, dst_reg, 0);
   1305				/* jmp byte_after_ldx */
   1306				EMIT2(0xEB, 0);
   1307
   1308				/* populate jmp_offset for JB above to jump to xor dst_reg */
   1309				end_of_jmp1[-1] = end_of_jmp2 - end_of_jmp1;
   1310				/* populate jmp_offset for JNC above to jump to start_of_ldx */
   1311				start_of_ldx = prog;
   1312				end_of_jmp2[-1] = start_of_ldx - end_of_jmp2;
   1313			}
   1314			emit_ldx(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn->off);
   1315			if (BPF_MODE(insn->code) == BPF_PROBE_MEM) {
   1316				struct exception_table_entry *ex;
   1317				u8 *_insn = image + proglen + (start_of_ldx - temp);
   1318				s64 delta;
   1319
   1320				/* populate jmp_offset for JMP above */
   1321				start_of_ldx[-1] = prog - start_of_ldx;
   1322
   1323				if (!bpf_prog->aux->extable)
   1324					break;
   1325
   1326				if (excnt >= bpf_prog->aux->num_exentries) {
   1327					pr_err("ex gen bug\n");
   1328					return -EFAULT;
   1329				}
   1330				ex = &bpf_prog->aux->extable[excnt++];
   1331
   1332				delta = _insn - (u8 *)&ex->insn;
   1333				if (!is_simm32(delta)) {
   1334					pr_err("extable->insn doesn't fit into 32-bit\n");
   1335					return -EFAULT;
   1336				}
   1337				/* switch ex to rw buffer for writes */
   1338				ex = (void *)rw_image + ((void *)ex - (void *)image);
   1339
   1340				ex->insn = delta;
   1341
   1342				ex->data = EX_TYPE_BPF;
   1343
   1344				if (dst_reg > BPF_REG_9) {
   1345					pr_err("verifier error\n");
   1346					return -EFAULT;
   1347				}
   1348				/*
   1349				 * Compute size of x86 insn and its target dest x86 register.
   1350				 * ex_handler_bpf() will use lower 8 bits to adjust
   1351				 * pt_regs->ip to jump over this x86 instruction
   1352				 * and upper bits to figure out which pt_regs to zero out.
   1353				 * End result: x86 insn "mov rbx, qword ptr [rax+0x14]"
   1354				 * of 4 bytes will be ignored and rbx will be zero inited.
   1355				 */
   1356				ex->fixup = (prog - start_of_ldx) | (reg2pt_regs[dst_reg] << 8);
   1357			}
   1358			break;
   1359
   1360		case BPF_STX | BPF_ATOMIC | BPF_W:
   1361		case BPF_STX | BPF_ATOMIC | BPF_DW:
   1362			if (insn->imm == (BPF_AND | BPF_FETCH) ||
   1363			    insn->imm == (BPF_OR | BPF_FETCH) ||
   1364			    insn->imm == (BPF_XOR | BPF_FETCH)) {
   1365				bool is64 = BPF_SIZE(insn->code) == BPF_DW;
   1366				u32 real_src_reg = src_reg;
   1367				u32 real_dst_reg = dst_reg;
   1368				u8 *branch_target;
   1369
   1370				/*
   1371				 * Can't be implemented with a single x86 insn.
   1372				 * Need to do a CMPXCHG loop.
   1373				 */
   1374
   1375				/* Will need RAX as a CMPXCHG operand so save R0 */
   1376				emit_mov_reg(&prog, true, BPF_REG_AX, BPF_REG_0);
   1377				if (src_reg == BPF_REG_0)
   1378					real_src_reg = BPF_REG_AX;
   1379				if (dst_reg == BPF_REG_0)
   1380					real_dst_reg = BPF_REG_AX;
   1381
   1382				branch_target = prog;
   1383				/* Load old value */
   1384				emit_ldx(&prog, BPF_SIZE(insn->code),
   1385					 BPF_REG_0, real_dst_reg, insn->off);
   1386				/*
   1387				 * Perform the (commutative) operation locally,
   1388				 * put the result in the AUX_REG.
   1389				 */
   1390				emit_mov_reg(&prog, is64, AUX_REG, BPF_REG_0);
   1391				maybe_emit_mod(&prog, AUX_REG, real_src_reg, is64);
   1392				EMIT2(simple_alu_opcodes[BPF_OP(insn->imm)],
   1393				      add_2reg(0xC0, AUX_REG, real_src_reg));
   1394				/* Attempt to swap in new value */
   1395				err = emit_atomic(&prog, BPF_CMPXCHG,
   1396						  real_dst_reg, AUX_REG,
   1397						  insn->off,
   1398						  BPF_SIZE(insn->code));
   1399				if (WARN_ON(err))
   1400					return err;
   1401				/*
   1402				 * ZF tells us whether we won the race. If it's
   1403				 * cleared we need to try again.
   1404				 */
   1405				EMIT2(X86_JNE, -(prog - branch_target) - 2);
   1406				/* Return the pre-modification value */
   1407				emit_mov_reg(&prog, is64, real_src_reg, BPF_REG_0);
   1408				/* Restore R0 after clobbering RAX */
   1409				emit_mov_reg(&prog, true, BPF_REG_0, BPF_REG_AX);
   1410				break;
   1411			}
   1412
   1413			err = emit_atomic(&prog, insn->imm, dst_reg, src_reg,
   1414					  insn->off, BPF_SIZE(insn->code));
   1415			if (err)
   1416				return err;
   1417			break;
   1418
   1419			/* call */
   1420		case BPF_JMP | BPF_CALL:
   1421			func = (u8 *) __bpf_call_base + imm32;
   1422			if (tail_call_reachable) {
   1423				/* mov rax, qword ptr [rbp - rounded_stack_depth - 8] */
   1424				EMIT3_off32(0x48, 0x8B, 0x85,
   1425					    -round_up(bpf_prog->aux->stack_depth, 8) - 8);
   1426				if (!imm32 || emit_call(&prog, func, image + addrs[i - 1] + 7))
   1427					return -EINVAL;
   1428			} else {
   1429				if (!imm32 || emit_call(&prog, func, image + addrs[i - 1]))
   1430					return -EINVAL;
   1431			}
   1432			break;
   1433
   1434		case BPF_JMP | BPF_TAIL_CALL:
   1435			if (imm32)
   1436				emit_bpf_tail_call_direct(&bpf_prog->aux->poke_tab[imm32 - 1],
   1437							  &prog, image + addrs[i - 1],
   1438							  callee_regs_used,
   1439							  bpf_prog->aux->stack_depth,
   1440							  ctx);
   1441			else
   1442				emit_bpf_tail_call_indirect(&prog,
   1443							    callee_regs_used,
   1444							    bpf_prog->aux->stack_depth,
   1445							    image + addrs[i - 1],
   1446							    ctx);
   1447			break;
   1448
   1449			/* cond jump */
   1450		case BPF_JMP | BPF_JEQ | BPF_X:
   1451		case BPF_JMP | BPF_JNE | BPF_X:
   1452		case BPF_JMP | BPF_JGT | BPF_X:
   1453		case BPF_JMP | BPF_JLT | BPF_X:
   1454		case BPF_JMP | BPF_JGE | BPF_X:
   1455		case BPF_JMP | BPF_JLE | BPF_X:
   1456		case BPF_JMP | BPF_JSGT | BPF_X:
   1457		case BPF_JMP | BPF_JSLT | BPF_X:
   1458		case BPF_JMP | BPF_JSGE | BPF_X:
   1459		case BPF_JMP | BPF_JSLE | BPF_X:
   1460		case BPF_JMP32 | BPF_JEQ | BPF_X:
   1461		case BPF_JMP32 | BPF_JNE | BPF_X:
   1462		case BPF_JMP32 | BPF_JGT | BPF_X:
   1463		case BPF_JMP32 | BPF_JLT | BPF_X:
   1464		case BPF_JMP32 | BPF_JGE | BPF_X:
   1465		case BPF_JMP32 | BPF_JLE | BPF_X:
   1466		case BPF_JMP32 | BPF_JSGT | BPF_X:
   1467		case BPF_JMP32 | BPF_JSLT | BPF_X:
   1468		case BPF_JMP32 | BPF_JSGE | BPF_X:
   1469		case BPF_JMP32 | BPF_JSLE | BPF_X:
   1470			/* cmp dst_reg, src_reg */
   1471			maybe_emit_mod(&prog, dst_reg, src_reg,
   1472				       BPF_CLASS(insn->code) == BPF_JMP);
   1473			EMIT2(0x39, add_2reg(0xC0, dst_reg, src_reg));
   1474			goto emit_cond_jmp;
   1475
   1476		case BPF_JMP | BPF_JSET | BPF_X:
   1477		case BPF_JMP32 | BPF_JSET | BPF_X:
   1478			/* test dst_reg, src_reg */
   1479			maybe_emit_mod(&prog, dst_reg, src_reg,
   1480				       BPF_CLASS(insn->code) == BPF_JMP);
   1481			EMIT2(0x85, add_2reg(0xC0, dst_reg, src_reg));
   1482			goto emit_cond_jmp;
   1483
   1484		case BPF_JMP | BPF_JSET | BPF_K:
   1485		case BPF_JMP32 | BPF_JSET | BPF_K:
   1486			/* test dst_reg, imm32 */
   1487			maybe_emit_1mod(&prog, dst_reg,
   1488					BPF_CLASS(insn->code) == BPF_JMP);
   1489			EMIT2_off32(0xF7, add_1reg(0xC0, dst_reg), imm32);
   1490			goto emit_cond_jmp;
   1491
   1492		case BPF_JMP | BPF_JEQ | BPF_K:
   1493		case BPF_JMP | BPF_JNE | BPF_K:
   1494		case BPF_JMP | BPF_JGT | BPF_K:
   1495		case BPF_JMP | BPF_JLT | BPF_K:
   1496		case BPF_JMP | BPF_JGE | BPF_K:
   1497		case BPF_JMP | BPF_JLE | BPF_K:
   1498		case BPF_JMP | BPF_JSGT | BPF_K:
   1499		case BPF_JMP | BPF_JSLT | BPF_K:
   1500		case BPF_JMP | BPF_JSGE | BPF_K:
   1501		case BPF_JMP | BPF_JSLE | BPF_K:
   1502		case BPF_JMP32 | BPF_JEQ | BPF_K:
   1503		case BPF_JMP32 | BPF_JNE | BPF_K:
   1504		case BPF_JMP32 | BPF_JGT | BPF_K:
   1505		case BPF_JMP32 | BPF_JLT | BPF_K:
   1506		case BPF_JMP32 | BPF_JGE | BPF_K:
   1507		case BPF_JMP32 | BPF_JLE | BPF_K:
   1508		case BPF_JMP32 | BPF_JSGT | BPF_K:
   1509		case BPF_JMP32 | BPF_JSLT | BPF_K:
   1510		case BPF_JMP32 | BPF_JSGE | BPF_K:
   1511		case BPF_JMP32 | BPF_JSLE | BPF_K:
   1512			/* test dst_reg, dst_reg to save one extra byte */
   1513			if (imm32 == 0) {
   1514				maybe_emit_mod(&prog, dst_reg, dst_reg,
   1515					       BPF_CLASS(insn->code) == BPF_JMP);
   1516				EMIT2(0x85, add_2reg(0xC0, dst_reg, dst_reg));
   1517				goto emit_cond_jmp;
   1518			}
   1519
   1520			/* cmp dst_reg, imm8/32 */
   1521			maybe_emit_1mod(&prog, dst_reg,
   1522					BPF_CLASS(insn->code) == BPF_JMP);
   1523
   1524			if (is_imm8(imm32))
   1525				EMIT3(0x83, add_1reg(0xF8, dst_reg), imm32);
   1526			else
   1527				EMIT2_off32(0x81, add_1reg(0xF8, dst_reg), imm32);
   1528
   1529emit_cond_jmp:		/* Convert BPF opcode to x86 */
   1530			switch (BPF_OP(insn->code)) {
   1531			case BPF_JEQ:
   1532				jmp_cond = X86_JE;
   1533				break;
   1534			case BPF_JSET:
   1535			case BPF_JNE:
   1536				jmp_cond = X86_JNE;
   1537				break;
   1538			case BPF_JGT:
   1539				/* GT is unsigned '>', JA in x86 */
   1540				jmp_cond = X86_JA;
   1541				break;
   1542			case BPF_JLT:
   1543				/* LT is unsigned '<', JB in x86 */
   1544				jmp_cond = X86_JB;
   1545				break;
   1546			case BPF_JGE:
   1547				/* GE is unsigned '>=', JAE in x86 */
   1548				jmp_cond = X86_JAE;
   1549				break;
   1550			case BPF_JLE:
   1551				/* LE is unsigned '<=', JBE in x86 */
   1552				jmp_cond = X86_JBE;
   1553				break;
   1554			case BPF_JSGT:
   1555				/* Signed '>', GT in x86 */
   1556				jmp_cond = X86_JG;
   1557				break;
   1558			case BPF_JSLT:
   1559				/* Signed '<', LT in x86 */
   1560				jmp_cond = X86_JL;
   1561				break;
   1562			case BPF_JSGE:
   1563				/* Signed '>=', GE in x86 */
   1564				jmp_cond = X86_JGE;
   1565				break;
   1566			case BPF_JSLE:
   1567				/* Signed '<=', LE in x86 */
   1568				jmp_cond = X86_JLE;
   1569				break;
   1570			default: /* to silence GCC warning */
   1571				return -EFAULT;
   1572			}
   1573			jmp_offset = addrs[i + insn->off] - addrs[i];
   1574			if (is_imm8(jmp_offset)) {
   1575				if (jmp_padding) {
   1576					/* To keep the jmp_offset valid, the extra bytes are
   1577					 * padded before the jump insn, so we subtract the
   1578					 * 2 bytes of jmp_cond insn from INSN_SZ_DIFF.
   1579					 *
   1580					 * If the previous pass already emits an imm8
   1581					 * jmp_cond, then this BPF insn won't shrink, so
   1582					 * "nops" is 0.
   1583					 *
   1584					 * On the other hand, if the previous pass emits an
   1585					 * imm32 jmp_cond, the extra 4 bytes(*) is padded to
   1586					 * keep the image from shrinking further.
   1587					 *
   1588					 * (*) imm32 jmp_cond is 6 bytes, and imm8 jmp_cond
   1589					 *     is 2 bytes, so the size difference is 4 bytes.
   1590					 */
   1591					nops = INSN_SZ_DIFF - 2;
   1592					if (nops != 0 && nops != 4) {
   1593						pr_err("unexpected jmp_cond padding: %d bytes\n",
   1594						       nops);
   1595						return -EFAULT;
   1596					}
   1597					emit_nops(&prog, nops);
   1598				}
   1599				EMIT2(jmp_cond, jmp_offset);
   1600			} else if (is_simm32(jmp_offset)) {
   1601				EMIT2_off32(0x0F, jmp_cond + 0x10, jmp_offset);
   1602			} else {
   1603				pr_err("cond_jmp gen bug %llx\n", jmp_offset);
   1604				return -EFAULT;
   1605			}
   1606
   1607			break;
   1608
   1609		case BPF_JMP | BPF_JA:
   1610			if (insn->off == -1)
   1611				/* -1 jmp instructions will always jump
   1612				 * backwards two bytes. Explicitly handling
   1613				 * this case avoids wasting too many passes
   1614				 * when there are long sequences of replaced
   1615				 * dead code.
   1616				 */
   1617				jmp_offset = -2;
   1618			else
   1619				jmp_offset = addrs[i + insn->off] - addrs[i];
   1620
   1621			if (!jmp_offset) {
   1622				/*
   1623				 * If jmp_padding is enabled, the extra nops will
   1624				 * be inserted. Otherwise, optimize out nop jumps.
   1625				 */
   1626				if (jmp_padding) {
   1627					/* There are 3 possible conditions.
   1628					 * (1) This BPF_JA is already optimized out in
   1629					 *     the previous run, so there is no need
   1630					 *     to pad any extra byte (0 byte).
   1631					 * (2) The previous pass emits an imm8 jmp,
   1632					 *     so we pad 2 bytes to match the previous
   1633					 *     insn size.
   1634					 * (3) Similarly, the previous pass emits an
   1635					 *     imm32 jmp, and 5 bytes is padded.
   1636					 */
   1637					nops = INSN_SZ_DIFF;
   1638					if (nops != 0 && nops != 2 && nops != 5) {
   1639						pr_err("unexpected nop jump padding: %d bytes\n",
   1640						       nops);
   1641						return -EFAULT;
   1642					}
   1643					emit_nops(&prog, nops);
   1644				}
   1645				break;
   1646			}
   1647emit_jmp:
   1648			if (is_imm8(jmp_offset)) {
   1649				if (jmp_padding) {
   1650					/* To avoid breaking jmp_offset, the extra bytes
   1651					 * are padded before the actual jmp insn, so
   1652					 * 2 bytes is subtracted from INSN_SZ_DIFF.
   1653					 *
   1654					 * If the previous pass already emits an imm8
   1655					 * jmp, there is nothing to pad (0 byte).
   1656					 *
   1657					 * If it emits an imm32 jmp (5 bytes) previously
   1658					 * and now an imm8 jmp (2 bytes), then we pad
   1659					 * (5 - 2 = 3) bytes to stop the image from
   1660					 * shrinking further.
   1661					 */
   1662					nops = INSN_SZ_DIFF - 2;
   1663					if (nops != 0 && nops != 3) {
   1664						pr_err("unexpected jump padding: %d bytes\n",
   1665						       nops);
   1666						return -EFAULT;
   1667					}
   1668					emit_nops(&prog, nops);
   1669				}
   1670				EMIT2(0xEB, jmp_offset);
   1671			} else if (is_simm32(jmp_offset)) {
   1672				EMIT1_off32(0xE9, jmp_offset);
   1673			} else {
   1674				pr_err("jmp gen bug %llx\n", jmp_offset);
   1675				return -EFAULT;
   1676			}
   1677			break;
   1678
   1679		case BPF_JMP | BPF_EXIT:
   1680			if (seen_exit) {
   1681				jmp_offset = ctx->cleanup_addr - addrs[i];
   1682				goto emit_jmp;
   1683			}
   1684			seen_exit = true;
   1685			/* Update cleanup_addr */
   1686			ctx->cleanup_addr = proglen;
   1687			pop_callee_regs(&prog, callee_regs_used);
   1688			EMIT1(0xC9);         /* leave */
   1689			EMIT1(0xC3);         /* ret */
   1690			break;
   1691
   1692		default:
   1693			/*
   1694			 * By design x86-64 JIT should support all BPF instructions.
   1695			 * This error will be seen if a new instruction was added
   1696			 * to the interpreter, but not to the JIT, or if there is
   1697			 * junk in bpf_prog.
   1698			 */
   1699			pr_err("bpf_jit: unknown opcode %02x\n", insn->code);
   1700			return -EINVAL;
   1701		}
   1702
   1703		ilen = prog - temp;
   1704		if (ilen > BPF_MAX_INSN_SIZE) {
   1705			pr_err("bpf_jit: fatal insn size error\n");
   1706			return -EFAULT;
   1707		}
   1708
   1709		if (image) {
   1710			/*
   1711			 * When populating the image, assert that:
   1712			 *
   1713			 *  i) We do not write beyond the allocated space, and
   1714			 * ii) addrs[i] did not change from the prior run, in order
   1715			 *     to validate assumptions made for computing branch
   1716			 *     displacements.
   1717			 */
   1718			if (unlikely(proglen + ilen > oldproglen ||
   1719				     proglen + ilen != addrs[i])) {
   1720				pr_err("bpf_jit: fatal error\n");
   1721				return -EFAULT;
   1722			}
   1723			memcpy(rw_image + proglen, temp, ilen);
   1724		}
   1725		proglen += ilen;
   1726		addrs[i] = proglen;
   1727		prog = temp;
   1728	}
   1729
   1730	if (image && excnt != bpf_prog->aux->num_exentries) {
   1731		pr_err("extable is not populated\n");
   1732		return -EFAULT;
   1733	}
   1734	return proglen;
   1735}
   1736
   1737static void save_regs(const struct btf_func_model *m, u8 **prog, int nr_args,
   1738		      int stack_size)
   1739{
   1740	int i;
   1741	/* Store function arguments to stack.
   1742	 * For a function that accepts two pointers, the sequence will be:
   1743	 * mov QWORD PTR [rbp-0x10],rdi
   1744	 * mov QWORD PTR [rbp-0x8],rsi
   1745	 */
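       	/* Arguments 1-5 arrive in the registers that BPF_REG_1..BPF_REG_5
       	 * map to (rdi, rsi, rdx, rcx, r8); the sixth x86-64 argument
       	 * register, r9, has no BPF counterpart, hence the X86_REG_R9
       	 * special case below.
       	 */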
   1746	for (i = 0; i < min(nr_args, 6); i++)
   1747		emit_stx(prog, bytes_to_bpf_size(m->arg_size[i]),
   1748			 BPF_REG_FP,
   1749			 i == 5 ? X86_REG_R9 : BPF_REG_1 + i,
   1750			 -(stack_size - i * 8));
   1751}
   1752
   1753static void restore_regs(const struct btf_func_model *m, u8 **prog, int nr_args,
   1754			 int stack_size)
   1755{
   1756	int i;
   1757
   1758	/* Restore function arguments from stack.
   1759	 * For a function that accepts two pointers, the sequence will be:
   1760	 * EMIT4(0x48, 0x8B, 0x7D, 0xF0); mov rdi,QWORD PTR [rbp-0x10]
   1761	 * EMIT4(0x48, 0x8B, 0x75, 0xF8); mov rsi,QWORD PTR [rbp-0x8]
   1762	 */
   1763	for (i = 0; i < min(nr_args, 6); i++)
   1764		emit_ldx(prog, bytes_to_bpf_size(m->arg_size[i]),
   1765			 i == 5 ? X86_REG_R9 : BPF_REG_1 + i,
   1766			 BPF_REG_FP,
   1767			 -(stack_size - i * 8));
   1768}
   1769
   1770static int invoke_bpf_prog(const struct btf_func_model *m, u8 **pprog,
   1771			   struct bpf_tramp_link *l, int stack_size,
   1772			   int run_ctx_off, bool save_ret)
   1773{
   1774	u8 *prog = *pprog;
   1775	u8 *jmp_insn;
   1776	int ctx_cookie_off = offsetof(struct bpf_tramp_run_ctx, bpf_cookie);
   1777	struct bpf_prog *p = l->link.prog;
   1778	u64 cookie = l->cookie;
   1779
   1780	/* mov rdi, cookie */
   1781	emit_mov_imm64(&prog, BPF_REG_1, (long) cookie >> 32, (u32) (long) cookie);
   1782
   1783	/* Prepare struct bpf_tramp_run_ctx.
   1784	 *
   1785	 * bpf_tramp_run_ctx is already preserved by
   1786	 * arch_prepare_bpf_trampoline().
   1787	 *
   1788	 * mov QWORD PTR [rbp - run_ctx_off + ctx_cookie_off], rdi
   1789	 */
   1790	emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_1, -run_ctx_off + ctx_cookie_off);
   1791
   1792	/* arg1: mov rdi, progs[i] */
   1793	emit_mov_imm64(&prog, BPF_REG_1, (long) p >> 32, (u32) (long) p);
   1794	/* arg2: lea rsi, [rbp - run_ctx_off] */
   1795	EMIT4(0x48, 0x8D, 0x75, -run_ctx_off);
   1796
   1797	if (emit_call(&prog,
   1798		      p->aux->sleepable ? __bpf_prog_enter_sleepable :
   1799		      __bpf_prog_enter, prog))
   1800			return -EINVAL;
   1801	/* remember prog start time returned by __bpf_prog_enter */
   1802	emit_mov_reg(&prog, true, BPF_REG_6, BPF_REG_0);
   1803
   1804	/* if (__bpf_prog_enter*(prog) == 0)
   1805	 *	goto skip_exec_of_prog;
   1806	 */
   1807	EMIT3(0x48, 0x85, 0xC0);  /* test rax,rax */
   1808	/* emit 2 nops that will be replaced with JE insn */
   1809	jmp_insn = prog;
   1810	emit_nops(&prog, 2);
   1811
   1812	/* arg1: lea rdi, [rbp - stack_size] */
   1813	EMIT4(0x48, 0x8D, 0x7D, -stack_size);
   1814	/* arg2: progs[i]->insnsi for interpreter */
   1815	if (!p->jited)
   1816		emit_mov_imm64(&prog, BPF_REG_2,
   1817			       (long) p->insnsi >> 32,
   1818			       (u32) (long) p->insnsi);
   1819	/* call JITed bpf program or interpreter */
   1820	if (emit_call(&prog, p->bpf_func, prog))
   1821		return -EINVAL;
   1822
   1823	/*
   1824	 * BPF_TRAMP_MODIFY_RETURN trampolines can modify the return
   1825	 * of the previous call which is then passed on the stack to
   1826	 * the next BPF program.
   1827	 *
   1828	 * BPF_TRAMP_FENTRY trampoline may need to return the return
   1829	 * value of BPF_PROG_TYPE_STRUCT_OPS prog.
   1830	 */
   1831	if (save_ret)
   1832		emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_0, -8);
   1833
   1834	/* replace 2 nops with JE insn, since jmp target is known */
   1835	jmp_insn[0] = X86_JE;
   1836	jmp_insn[1] = prog - jmp_insn - 2;
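       	/* X86_JE is the 2-byte short form (0x74 + rel8), so it exactly
       	 * overwrites the two placeholder nops; rel8 is relative to the end
       	 * of the 2-byte insn, hence the "- 2" above.
       	 */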
   1837
   1838	/* arg1: mov rdi, progs[i] */
   1839	emit_mov_imm64(&prog, BPF_REG_1, (long) p >> 32, (u32) (long) p);
   1840	/* arg2: mov rsi, rbx <- start time in nsec */
   1841	emit_mov_reg(&prog, true, BPF_REG_2, BPF_REG_6);
   1842	/* arg3: lea rdx, [rbp - run_ctx_off] */
   1843	EMIT4(0x48, 0x8D, 0x55, -run_ctx_off);
   1844	if (emit_call(&prog,
   1845		      p->aux->sleepable ? __bpf_prog_exit_sleepable :
   1846		      __bpf_prog_exit, prog))
   1847			return -EINVAL;
   1848
   1849	*pprog = prog;
   1850	return 0;
   1851}
   1852
   1853static void emit_align(u8 **pprog, u32 align)
   1854{
   1855	u8 *target, *prog = *pprog;
   1856
   1857	target = PTR_ALIGN(prog, align);
   1858	if (target != prog)
   1859		emit_nops(&prog, target - prog);
   1860
   1861	*pprog = prog;
   1862}
   1863
   1864static int emit_cond_near_jump(u8 **pprog, void *func, void *ip, u8 jmp_cond)
   1865{
   1866	u8 *prog = *pprog;
   1867	s64 offset;
   1868
   1869	offset = func - (ip + 2 + 4);
   1870	if (!is_simm32(offset)) {
   1871		pr_err("Target %p is out of range\n", func);
   1872		return -EINVAL;
   1873	}
   1874	EMIT2_off32(0x0F, jmp_cond + 0x10, offset);
   1875	*pprog = prog;
   1876	return 0;
   1877}
   1878
   1879static int invoke_bpf(const struct btf_func_model *m, u8 **pprog,
   1880		      struct bpf_tramp_links *tl, int stack_size,
   1881		      int run_ctx_off, bool save_ret)
   1882{
   1883	int i;
   1884	u8 *prog = *pprog;
   1885
   1886	for (i = 0; i < tl->nr_links; i++) {
   1887		if (invoke_bpf_prog(m, &prog, tl->links[i], stack_size,
   1888				    run_ctx_off, save_ret))
   1889			return -EINVAL;
   1890	}
   1891	*pprog = prog;
   1892	return 0;
   1893}
   1894
   1895static int invoke_bpf_mod_ret(const struct btf_func_model *m, u8 **pprog,
   1896			      struct bpf_tramp_links *tl, int stack_size,
   1897			      int run_ctx_off, u8 **branches)
   1898{
   1899	u8 *prog = *pprog;
   1900	int i;
   1901
   1902	/* The first fmod_ret program will receive a garbage return value.
   1903	 * Set this to 0 to avoid confusing the program.
   1904	 */
   1905	emit_mov_imm32(&prog, false, BPF_REG_0, 0);
   1906	emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_0, -8);
   1907	for (i = 0; i < tl->nr_links; i++) {
   1908		if (invoke_bpf_prog(m, &prog, tl->links[i], stack_size, run_ctx_off, true))
   1909			return -EINVAL;
   1910
   1911		/* mod_ret prog stored return value into [rbp - 8]. Emit:
   1912		 * if (*(u64 *)(rbp - 8) != 0)
   1913		 *	goto do_fexit;
   1914		 */
   1915		/* cmp QWORD PTR [rbp - 0x8], 0x0 */
   1916		EMIT4(0x48, 0x83, 0x7d, 0xf8); EMIT1(0x00);
   1917
   1918		/* Save the location of the branch and generate 6 nops
   1919		 * (4 bytes for an offset and 2 bytes for the jump). These nops
   1920		 * are replaced with a conditional jump once do_fexit (i.e. the
   1921		 * start of the fexit invocation) is finalized.
   1922		 */
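       		/* The 6 nop bytes match the size of the near-form jcc
       		 * (0x0F 0x8x + rel32) that emit_cond_near_jump() will later
       		 * write in their place.
       		 */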
   1923		branches[i] = prog;
   1924		emit_nops(&prog, 4 + 2);
   1925	}
   1926
   1927	*pprog = prog;
   1928	return 0;
   1929}
   1930
   1931static bool is_valid_bpf_tramp_flags(unsigned int flags)
   1932{
   1933	if ((flags & BPF_TRAMP_F_RESTORE_REGS) &&
   1934	    (flags & BPF_TRAMP_F_SKIP_FRAME))
   1935		return false;
   1936
   1937	/*
   1938	 * BPF_TRAMP_F_RET_FENTRY_RET is only used by bpf_struct_ops,
   1939	 * and it must be used alone.
   1940	 */
   1941	if ((flags & BPF_TRAMP_F_RET_FENTRY_RET) &&
   1942	    (flags & ~BPF_TRAMP_F_RET_FENTRY_RET))
   1943		return false;
   1944
   1945	return true;
   1946}
   1947
   1948/* Example:
   1949 * __be16 eth_type_trans(struct sk_buff *skb, struct net_device *dev);
   1950 * its 'struct btf_func_model' will have nr_args=2
   1951 * The assembly code when eth_type_trans is executing after trampoline:
   1952 *
   1953 * push rbp
   1954 * mov rbp, rsp
   1955 * sub rsp, 16                     // space for skb and dev
   1956 * push rbx                        // temp regs to pass start time
   1957 * mov qword ptr [rbp - 16], rdi   // save skb pointer to stack
   1958 * mov qword ptr [rbp - 8], rsi    // save dev pointer to stack
   1959 * call __bpf_prog_enter           // rcu_read_lock and preempt_disable
   1960 * mov rbx, rax                    // remember start time if bpf stats are enabled
   1961 * lea rdi, [rbp - 16]             // R1==ctx of bpf prog
   1962 * call addr_of_jited_FENTRY_prog
   1963 * movabsq rdi, 64bit_addr_of_struct_bpf_prog  // unused if bpf stats are off
   1964 * mov rsi, rbx                    // prog start time
   1965 * call __bpf_prog_exit            // rcu_read_unlock, preempt_enable and stats math
   1966 * mov rdi, qword ptr [rbp - 16]   // restore skb pointer from stack
   1967 * mov rsi, qword ptr [rbp - 8]    // restore dev pointer from stack
   1968 * pop rbx
   1969 * leave
   1970 * ret
   1971 *
   1972 * eth_type_trans has 5 byte nop at the beginning. These 5 bytes will be
   1973 * replaced with 'call generated_bpf_trampoline'. When it returns
   1974 * eth_type_trans will continue executing with original skb and dev pointers.
   1975 *
   1976 * The assembly code when eth_type_trans is called from trampoline:
   1977 *
   1978 * push rbp
   1979 * mov rbp, rsp
   1980 * sub rsp, 24                     // space for skb, dev, return value
   1981 * push rbx                        // temp regs to pass start time
   1982 * mov qword ptr [rbp - 24], rdi   // save skb pointer to stack
   1983 * mov qword ptr [rbp - 16], rsi   // save dev pointer to stack
   1984 * call __bpf_prog_enter           // rcu_read_lock and preempt_disable
   1985 * mov rbx, rax                    // remember start time if bpf stats are enabled
   1986 * lea rdi, [rbp - 24]             // R1==ctx of bpf prog
   1987 * call addr_of_jited_FENTRY_prog  // bpf prog can access skb and dev
   1988 * movabsq rdi, 64bit_addr_of_struct_bpf_prog  // unused if bpf stats are off
   1989 * mov rsi, rbx                    // prog start time
   1990 * call __bpf_prog_exit            // rcu_read_unlock, preempt_enable and stats math
   1991 * mov rdi, qword ptr [rbp - 24]   // restore skb pointer from stack
   1992 * mov rsi, qword ptr [rbp - 16]   // restore dev pointer from stack
   1993 * call eth_type_trans+5           // execute body of eth_type_trans
   1994 * mov qword ptr [rbp - 8], rax    // save return value
   1995 * call __bpf_prog_enter           // rcu_read_lock and preempt_disable
   1996 * mov rbx, rax                    // remember start time if bpf stats are enabled
   1997 * lea rdi, [rbp - 24]             // R1==ctx of bpf prog
   1998 * call addr_of_jited_FEXIT_prog   // bpf prog can access skb, dev, return value
   1999 * movabsq rdi, 64bit_addr_of_struct_bpf_prog  // unused if bpf stats are off
   2000 * mov rsi, rbx                    // prog start time
   2001 * call __bpf_prog_exit            // rcu_read_unlock, preempt_enable and stats math
   2002 * mov rax, qword ptr [rbp - 8]    // restore eth_type_trans's return value
   2003 * pop rbx
   2004 * leave
   2005 * add rsp, 8                      // skip eth_type_trans's frame
   2006 * ret                             // return to its caller
   2007 */
   2008int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image, void *image_end,
   2009				const struct btf_func_model *m, u32 flags,
   2010				struct bpf_tramp_links *tlinks,
   2011				void *orig_call)
   2012{
   2013	int ret, i, nr_args = m->nr_args;
   2014	int regs_off, ip_off, args_off, stack_size = nr_args * 8, run_ctx_off;
   2015	struct bpf_tramp_links *fentry = &tlinks[BPF_TRAMP_FENTRY];
   2016	struct bpf_tramp_links *fexit = &tlinks[BPF_TRAMP_FEXIT];
   2017	struct bpf_tramp_links *fmod_ret = &tlinks[BPF_TRAMP_MODIFY_RETURN];
   2018	u8 **branches = NULL;
   2019	u8 *prog;
   2020	bool save_ret;
   2021
   2022	/* x86-64 supports up to 6 arguments. 7+ can be added in the future */
   2023	if (nr_args > 6)
   2024		return -ENOTSUPP;
   2025
   2026	if (!is_valid_bpf_tramp_flags(flags))
   2027		return -EINVAL;
   2028
   2029	/* Generated trampoline stack layout:
   2030	 *
   2031	 * RBP + 8         [ return address  ]
   2032	 * RBP + 0         [ RBP             ]
   2033	 *
   2034	 * RBP - 8         [ return value    ]  BPF_TRAMP_F_CALL_ORIG or
   2035	 *                                      BPF_TRAMP_F_RET_FENTRY_RET flags
   2036	 *
   2037	 *                 [ reg_argN        ]  always
   2038	 *                 [ ...             ]
   2039	 * RBP - regs_off  [ reg_arg1        ]  program's ctx pointer
   2040	 *
   2041	 * RBP - args_off  [ args count      ]  always
   2042	 *
   2043	 * RBP - ip_off    [ traced function ]  BPF_TRAMP_F_IP_ARG flag
   2044	 *
   2045	 * RBP - run_ctx_off [ bpf_tramp_run_ctx ]
   2046	 */
   2047
   2048	/* room for return value of orig_call or fentry prog */
   2049	save_ret = flags & (BPF_TRAMP_F_CALL_ORIG | BPF_TRAMP_F_RET_FENTRY_RET);
   2050	if (save_ret)
   2051		stack_size += 8;
   2052
   2053	regs_off = stack_size;
   2054
   2055	/* args count  */
   2056	stack_size += 8;
   2057	args_off = stack_size;
   2058
   2059	if (flags & BPF_TRAMP_F_IP_ARG)
   2060		stack_size += 8; /* room for IP address argument */
   2061
   2062	ip_off = stack_size;
   2063
   2064	stack_size += (sizeof(struct bpf_tramp_run_ctx) + 7) & ~0x7;
   2065	run_ctx_off = stack_size;
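       	/* A worked example with hypothetical values, assuming nr_args == 2
       	 * and no flags are set: regs_off == 16, args_off == 24, ip_off == 24
       	 * (no BPF_TRAMP_F_IP_ARG slot is reserved) and run_ctx_off == 24 +
       	 * sizeof(struct bpf_tramp_run_ctx) rounded up to a multiple of 8.
       	 */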
   2066
   2067	if (flags & BPF_TRAMP_F_SKIP_FRAME) {
   2068		/* skip patched call instruction and point orig_call to actual
   2069		 * body of the kernel function.
   2070		 */
   2071		if (is_endbr(*(u32 *)orig_call))
   2072			orig_call += ENDBR_INSN_SIZE;
   2073		orig_call += X86_PATCH_SIZE;
   2074	}
   2075
   2076	prog = image;
   2077
   2078	EMIT_ENDBR();
   2079	EMIT1(0x55);		 /* push rbp */
   2080	EMIT3(0x48, 0x89, 0xE5); /* mov rbp, rsp */
   2081	EMIT4(0x48, 0x83, 0xEC, stack_size); /* sub rsp, stack_size */
   2082	EMIT1(0x53);		 /* push rbx */
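       	/* rbx (BPF_REG_6 in the JIT's register map) is callee-saved and is
       	 * used by invoke_bpf_prog() to carry the start time between
       	 * __bpf_prog_enter and __bpf_prog_exit, hence it is preserved here.
       	 */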
   2083
   2084	/* Store number of arguments of the traced function:
   2085	 *   mov rax, nr_args
   2086	 *   mov QWORD PTR [rbp - args_off], rax
   2087	 */
   2088	emit_mov_imm64(&prog, BPF_REG_0, 0, (u32) nr_args);
   2089	emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_0, -args_off);
   2090
   2091	if (flags & BPF_TRAMP_F_IP_ARG) {
   2092		/* Store IP address of the traced function:
   2093		 * mov rax, QWORD PTR [rbp + 8]
   2094		 * sub rax, X86_PATCH_SIZE
   2095		 * mov QWORD PTR [rbp - ip_off], rax
   2096		 */
   2097		emit_ldx(&prog, BPF_DW, BPF_REG_0, BPF_REG_FP, 8);
   2098		EMIT4(0x48, 0x83, 0xe8, X86_PATCH_SIZE);
   2099		emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_0, -ip_off);
   2100	}
   2101
   2102	save_regs(m, &prog, nr_args, regs_off);
   2103
   2104	if (flags & BPF_TRAMP_F_CALL_ORIG) {
   2105		/* arg1: mov rdi, im */
   2106		emit_mov_imm64(&prog, BPF_REG_1, (long) im >> 32, (u32) (long) im);
   2107		if (emit_call(&prog, __bpf_tramp_enter, prog)) {
   2108			ret = -EINVAL;
   2109			goto cleanup;
   2110		}
   2111	}
   2112
   2113	if (fentry->nr_links)
   2114		if (invoke_bpf(m, &prog, fentry, regs_off, run_ctx_off,
   2115			       flags & BPF_TRAMP_F_RET_FENTRY_RET))
   2116			return -EINVAL;
   2117
   2118	if (fmod_ret->nr_links) {
   2119		branches = kcalloc(fmod_ret->nr_links, sizeof(u8 *),
   2120				   GFP_KERNEL);
   2121		if (!branches)
   2122			return -ENOMEM;
   2123
   2124		if (invoke_bpf_mod_ret(m, &prog, fmod_ret, regs_off,
   2125				       run_ctx_off, branches)) {
   2126			ret = -EINVAL;
   2127			goto cleanup;
   2128		}
   2129	}
   2130
   2131	if (flags & BPF_TRAMP_F_CALL_ORIG) {
   2132		restore_regs(m, &prog, nr_args, regs_off);
   2133
   2134		/* call original function */
   2135		if (emit_call(&prog, orig_call, prog)) {
   2136			ret = -EINVAL;
   2137			goto cleanup;
   2138		}
   2139		/* remember return value on the stack for bpf prog to access */
   2140		emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_0, -8);
   2141		im->ip_after_call = prog;
   2142		memcpy(prog, x86_nops[5], X86_PATCH_SIZE);
   2143		prog += X86_PATCH_SIZE;
   2144	}
   2145
   2146	if (fmod_ret->nr_links) {
   2147		/* From Intel 64 and IA-32 Architectures Optimization
   2148		 * Reference Manual, 3.4.1.4 Code Alignment, Assembly/Compiler
   2149		 * Coding Rule 11: All branch targets should be 16-byte
   2150		 * aligned.
   2151		 */
   2152		emit_align(&prog, 16);
   2153		/* Update the branches saved in invoke_bpf_mod_ret with the
   2154		 * aligned address of do_fexit.
   2155		 */
   2156		for (i = 0; i < fmod_ret->nr_links; i++)
   2157			emit_cond_near_jump(&branches[i], prog, branches[i],
   2158					    X86_JNE);
   2159	}
   2160
   2161	if (fexit->nr_links)
   2162		if (invoke_bpf(m, &prog, fexit, regs_off, run_ctx_off, false)) {
   2163			ret = -EINVAL;
   2164			goto cleanup;
   2165		}
   2166
   2167	if (flags & BPF_TRAMP_F_RESTORE_REGS)
   2168		restore_regs(m, &prog, nr_args, regs_off);
   2169
   2170	/* This needs to be done regardless. If there were fmod_ret programs,
   2171	 * the return value is only updated on the stack and still needs to be
   2172	 * restored to R0.
   2173	 */
   2174	if (flags & BPF_TRAMP_F_CALL_ORIG) {
   2175		im->ip_epilogue = prog;
   2176		/* arg1: mov rdi, im */
   2177		emit_mov_imm64(&prog, BPF_REG_1, (long) im >> 32, (u32) (long) im);
   2178		if (emit_call(&prog, __bpf_tramp_exit, prog)) {
   2179			ret = -EINVAL;
   2180			goto cleanup;
   2181		}
   2182	}
   2183	/* restore return value of orig_call or fentry prog back into RAX */
   2184	if (save_ret)
   2185		emit_ldx(&prog, BPF_DW, BPF_REG_0, BPF_REG_FP, -8);
   2186
   2187	EMIT1(0x5B); /* pop rbx */
   2188	EMIT1(0xC9); /* leave */
   2189	if (flags & BPF_TRAMP_F_SKIP_FRAME)
   2190		/* skip our return address and return to parent */
   2191		EMIT4(0x48, 0x83, 0xC4, 8); /* add rsp, 8 */
   2192	EMIT1(0xC3); /* ret */
   2193	/* Make sure the trampoline generation logic doesn't overflow */
   2194	if (WARN_ON_ONCE(prog > (u8 *)image_end - BPF_INSN_SAFETY)) {
   2195		ret = -EFAULT;
   2196		goto cleanup;
   2197	}
   2198	ret = prog - (u8 *)image;
   2199
   2200cleanup:
   2201	kfree(branches);
   2202	return ret;
   2203}
   2204
   2205static int emit_bpf_dispatcher(u8 **pprog, int a, int b, s64 *progs)
   2206{
   2207	u8 *jg_reloc, *prog = *pprog;
   2208	int pivot, err, jg_bytes = 1;
   2209	s64 jg_offset;
   2210
   2211	if (a == b) {
   2212		/* Leaf node of recursion, i.e. not a range of indices
   2213		 * anymore.
   2214		 */
   2215		EMIT1(add_1mod(0x48, BPF_REG_3));	/* cmp rdx,func */
   2216		if (!is_simm32(progs[a]))
   2217			return -1;
   2218		EMIT2_off32(0x81, add_1reg(0xF8, BPF_REG_3),
   2219			    progs[a]);
   2220		err = emit_cond_near_jump(&prog,	/* je func */
   2221					  (void *)progs[a], prog,
   2222					  X86_JE);
   2223		if (err)
   2224			return err;
   2225
   2226		emit_indirect_jump(&prog, 2 /* rdx */, prog);
   2227
   2228		*pprog = prog;
   2229		return 0;
   2230	}
   2231
   2232	/* Not a leaf node, so we pivot, and recursively descend into
   2233	 * the lower and upper ranges.
   2234	 */
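       	/* For example, with four sorted target addresses the root compares
       	 * rdx against progs[1]; "jg" branches to the subtree that handles
       	 * progs[2..3] while the fall-through handles progs[0..1], and each
       	 * leaf ends in a cmp/je pair with an indirect jump as the fallback.
       	 */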
   2235	pivot = (b - a) / 2;
   2236	EMIT1(add_1mod(0x48, BPF_REG_3));		/* cmp rdx,func */
   2237	if (!is_simm32(progs[a + pivot]))
   2238		return -1;
   2239	EMIT2_off32(0x81, add_1reg(0xF8, BPF_REG_3), progs[a + pivot]);
   2240
   2241	if (pivot > 2) {				/* jg upper_part */
   2242		/* Require near jump. */
   2243		jg_bytes = 4;
   2244		EMIT2_off32(0x0F, X86_JG + 0x10, 0);
   2245	} else {
   2246		EMIT2(X86_JG, 0);
   2247	}
   2248	jg_reloc = prog;
   2249
   2250	err = emit_bpf_dispatcher(&prog, a, a + pivot,	/* emit lower_part */
   2251				  progs);
   2252	if (err)
   2253		return err;
   2254
   2255	/* From Intel 64 and IA-32 Architectures Optimization
   2256	 * Reference Manual, 3.4.1.4 Code Alignment, Assembly/Compiler
   2257	 * Coding Rule 11: All branch targets should be 16-byte
   2258	 * aligned.
   2259	 */
   2260	emit_align(&prog, 16);
   2261	jg_offset = prog - jg_reloc;
   2262	emit_code(jg_reloc - jg_bytes, jg_offset, jg_bytes);
   2263
   2264	err = emit_bpf_dispatcher(&prog, a + pivot + 1,	/* emit upper_part */
   2265				  b, progs);
   2266	if (err)
   2267		return err;
   2268
   2269	*pprog = prog;
   2270	return 0;
   2271}
   2272
   2273static int cmp_ips(const void *a, const void *b)
   2274{
   2275	const s64 *ipa = a;
   2276	const s64 *ipb = b;
   2277
   2278	if (*ipa > *ipb)
   2279		return 1;
   2280	if (*ipa < *ipb)
   2281		return -1;
   2282	return 0;
   2283}
   2284
   2285int arch_prepare_bpf_dispatcher(void *image, s64 *funcs, int num_funcs)
   2286{
   2287	u8 *prog = image;
   2288
   2289	sort(funcs, num_funcs, sizeof(funcs[0]), cmp_ips, NULL);
   2290	return emit_bpf_dispatcher(&prog, 0, num_funcs - 1, funcs);
   2291}
   2292
   2293struct x64_jit_data {
   2294	struct bpf_binary_header *rw_header;
   2295	struct bpf_binary_header *header;
   2296	int *addrs;
   2297	u8 *image;
   2298	int proglen;
   2299	struct jit_context ctx;
   2300};
   2301
   2302#define MAX_PASSES 20
   2303#define PADDING_PASSES (MAX_PASSES - 5)
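       /* Padding is switched on for the last few passes (see the pass loop in
        * bpf_int_jit_compile() below) so that the image can no longer shrink
        * and the loop converges before MAX_PASSES runs out.
        */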
   2304
   2305struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
   2306{
   2307	struct bpf_binary_header *rw_header = NULL;
   2308	struct bpf_binary_header *header = NULL;
   2309	struct bpf_prog *tmp, *orig_prog = prog;
   2310	struct x64_jit_data *jit_data;
   2311	int proglen, oldproglen = 0;
   2312	struct jit_context ctx = {};
   2313	bool tmp_blinded = false;
   2314	bool extra_pass = false;
   2315	bool padding = false;
   2316	u8 *rw_image = NULL;
   2317	u8 *image = NULL;
   2318	int *addrs;
   2319	int pass;
   2320	int i;
   2321
   2322	if (!prog->jit_requested)
   2323		return orig_prog;
   2324
   2325	tmp = bpf_jit_blind_constants(prog);
   2326	/*
   2327	 * If blinding was requested and we failed during blinding,
   2328	 * we must fall back to the interpreter.
   2329	 */
   2330	if (IS_ERR(tmp))
   2331		return orig_prog;
   2332	if (tmp != prog) {
   2333		tmp_blinded = true;
   2334		prog = tmp;
   2335	}
   2336
   2337	jit_data = prog->aux->jit_data;
   2338	if (!jit_data) {
   2339		jit_data = kzalloc(sizeof(*jit_data), GFP_KERNEL);
   2340		if (!jit_data) {
   2341			prog = orig_prog;
   2342			goto out;
   2343		}
   2344		prog->aux->jit_data = jit_data;
   2345	}
   2346	addrs = jit_data->addrs;
   2347	if (addrs) {
   2348		ctx = jit_data->ctx;
   2349		oldproglen = jit_data->proglen;
   2350		image = jit_data->image;
   2351		header = jit_data->header;
   2352		rw_header = jit_data->rw_header;
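       		/* image points into the final packed allocation; derive the
       		 * matching address inside the writable rw_header copy at the
       		 * same offset from its header.
       		 */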
   2353		rw_image = (void *)rw_header + ((void *)image - (void *)header);
   2354		extra_pass = true;
   2355		padding = true;
   2356		goto skip_init_addrs;
   2357	}
   2358	addrs = kvmalloc_array(prog->len + 1, sizeof(*addrs), GFP_KERNEL);
   2359	if (!addrs) {
   2360		prog = orig_prog;
   2361		goto out_addrs;
   2362	}
   2363
   2364	/*
   2365	 * Before the first pass, make a rough estimate of addrs[]:
   2366	 * each BPF instruction is translated to less than 64 bytes.
   2367	 */
   2368	for (proglen = 0, i = 0; i <= prog->len; i++) {
   2369		proglen += 64;
   2370		addrs[i] = proglen;
   2371	}
   2372	ctx.cleanup_addr = proglen;
   2373skip_init_addrs:
   2374
   2375	/*
   2376	 * The JITed image shrinks with every pass and the loop iterates
   2377	 * until the image stops shrinking. Very large BPF programs
   2378	 * may converge only on the last pass. In such a case, do one more
   2379	 * pass to emit the final image.
   2380	 */
   2381	for (pass = 0; pass < MAX_PASSES || image; pass++) {
   2382		if (!padding && pass >= PADDING_PASSES)
   2383			padding = true;
   2384		proglen = do_jit(prog, addrs, image, rw_image, oldproglen, &ctx, padding);
   2385		if (proglen <= 0) {
   2386out_image:
   2387			image = NULL;
   2388			if (header) {
   2389				bpf_arch_text_copy(&header->size, &rw_header->size,
   2390						   sizeof(rw_header->size));
   2391				bpf_jit_binary_pack_free(header, rw_header);
   2392			}
   2393			/* Fall back to interpreter mode */
   2394			prog = orig_prog;
   2395			if (extra_pass) {
   2396				prog->bpf_func = NULL;
   2397				prog->jited = 0;
   2398				prog->jited_len = 0;
   2399			}
   2400			goto out_addrs;
   2401		}
   2402		if (image) {
   2403			if (proglen != oldproglen) {
   2404				pr_err("bpf_jit: proglen=%d != oldproglen=%d\n",
   2405				       proglen, oldproglen);
   2406				goto out_image;
   2407			}
   2408			break;
   2409		}
   2410		if (proglen == oldproglen) {
   2411			/*
   2412			 * The number of entries in extable is the number of BPF_LDX
   2413			 * insns that access kernel memory via "pointer to BTF type".
   2414			 * The verifier changed their opcode from LDX|MEM|size
   2415			 * to LDX|PROBE_MEM|size to make JITing easier.
   2416			 */
   2417			u32 align = __alignof__(struct exception_table_entry);
   2418			u32 extable_size = prog->aux->num_exentries *
   2419				sizeof(struct exception_table_entry);
   2420
   2421			/* allocate module memory for x86 insns and extable */
   2422			header = bpf_jit_binary_pack_alloc(roundup(proglen, align) + extable_size,
   2423							   &image, align, &rw_header, &rw_image,
   2424							   jit_fill_hole);
   2425			if (!header) {
   2426				prog = orig_prog;
   2427				goto out_addrs;
   2428			}
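       			/* The exception table lives in the same allocation,
       			 * right after the JITed code rounded up to the
       			 * extable's alignment (this is the extra extable_size
       			 * requested from the packed allocator above).
       			 */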
   2429			prog->aux->extable = (void *) image + roundup(proglen, align);
   2430		}
   2431		oldproglen = proglen;
   2432		cond_resched();
   2433	}
   2434
   2435	if (bpf_jit_enable > 1)
   2436		bpf_jit_dump(prog->len, proglen, pass + 1, image);
   2437
   2438	if (image) {
   2439		if (!prog->is_func || extra_pass) {
   2440			/*
   2441			 * bpf_jit_binary_pack_finalize fails in two scenarios:
   2442			 *   1) header is not pointing to proper module memory;
   2443			 *   2) the arch doesn't support bpf_arch_text_copy().
   2444			 *
   2445			 * Both cases are serious bugs and justify WARN_ON.
   2446			 */
   2447			if (WARN_ON(bpf_jit_binary_pack_finalize(prog, header, rw_header))) {
   2448				/* header has been freed */
   2449				header = NULL;
   2450				goto out_image;
   2451			}
   2452
   2453			bpf_tail_call_direct_fixup(prog);
   2454		} else {
   2455			jit_data->addrs = addrs;
   2456			jit_data->ctx = ctx;
   2457			jit_data->proglen = proglen;
   2458			jit_data->image = image;
   2459			jit_data->header = header;
   2460			jit_data->rw_header = rw_header;
   2461		}
   2462		prog->bpf_func = (void *)image;
   2463		prog->jited = 1;
   2464		prog->jited_len = proglen;
   2465	} else {
   2466		prog = orig_prog;
   2467	}
   2468
   2469	if (!image || !prog->is_func || extra_pass) {
   2470		if (image)
   2471			bpf_prog_fill_jited_linfo(prog, addrs + 1);
   2472out_addrs:
   2473		kvfree(addrs);
   2474		kfree(jit_data);
   2475		prog->aux->jit_data = NULL;
   2476	}
   2477out:
   2478	if (tmp_blinded)
   2479		bpf_jit_prog_release_other(prog, prog == orig_prog ?
   2480					   tmp : orig_prog);
   2481	return prog;
   2482}
   2483
   2484bool bpf_jit_supports_kfunc_call(void)
   2485{
   2486	return true;
   2487}
   2488
   2489void *bpf_arch_text_copy(void *dst, void *src, size_t len)
   2490{
   2491	if (text_poke_copy(dst, src, len) == NULL)
   2492		return ERR_PTR(-EINVAL);
   2493	return dst;
   2494}