From b4e432f1000a171d901e42551459059831925770 Mon Sep 17 00:00:00 2001 From: Daniel Borkmann Date: Thu, 10 Aug 2017 01:40:02 +0200 Subject: [PATCH] bpf: enable BPF_J{LT, LE, SLT, SLE} opcodes in verifier Enable the newly added jump opcodes, main parts are in two different areas, namely direct packet access and dynamic map value access. For the direct packet access, we now allow for the following two new patterns to match in order to trigger markings with find_good_pkt_pointers(): Variant 1 (access ok when taking the branch): 0: (61) r2 = *(u32 *)(r1 +76) 1: (61) r3 = *(u32 *)(r1 +80) 2: (bf) r0 = r2 3: (07) r0 += 8 4: (ad) if r0 < r3 goto pc+2 R0=pkt(id=0,off=8,r=0) R1=ctx R2=pkt(id=0,off=0,r=0) R3=pkt_end R10=fp 5: (b7) r0 = 0 6: (95) exit from 4 to 7: R0=pkt(id=0,off=8,r=8) R1=ctx R2=pkt(id=0,off=0,r=8) R3=pkt_end R10=fp 7: (71) r0 = *(u8 *)(r2 +0) 8: (05) goto pc-4 5: (b7) r0 = 0 6: (95) exit processed 11 insns, stack depth 0 Variant 2 (access ok on fall-through): 0: (61) r2 = *(u32 *)(r1 +76) 1: (61) r3 = *(u32 *)(r1 +80) 2: (bf) r0 = r2 3: (07) r0 += 8 4: (bd) if r3 <= r0 goto pc+1 R0=pkt(id=0,off=8,r=8) R1=ctx R2=pkt(id=0,off=0,r=8) R3=pkt_end R10=fp 5: (71) r0 = *(u8 *)(r2 +0) 6: (b7) r0 = 1 7: (95) exit from 4 to 6: R0=pkt(id=0,off=8,r=0) R1=ctx R2=pkt(id=0,off=0,r=0) R3=pkt_end R10=fp 6: (b7) r0 = 1 7: (95) exit processed 10 insns, stack depth 0 The above two basically just swap the branches where we need to handle an exception and allow packet access compared to the two already existing variants for find_good_pkt_pointers(). For the dynamic map value access, we add the new instructions to reg_set_min_max() and reg_set_min_max_inv() in order to learn bounds. Verifier test cases for both are added in a follow-up patch. Signed-off-by: Daniel Borkmann Acked-by: Alexei Starovoitov Acked-by: John Fastabend Signed-off-by: David S. Miller --- kernel/bpf/verifier.c | 62 ++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 58 insertions(+), 4 deletions(-) diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c index 8160a81a40bf..ecc590e01a1d 100644 --- a/kernel/bpf/verifier.c +++ b/kernel/bpf/verifier.c @@ -312,11 +312,15 @@ static const char *const bpf_jmp_string[16] = { [BPF_JA >> 4] = "jmp", [BPF_JEQ >> 4] = "==", [BPF_JGT >> 4] = ">", + [BPF_JLT >> 4] = "<", [BPF_JGE >> 4] = ">=", + [BPF_JLE >> 4] = "<=", [BPF_JSET >> 4] = "&", [BPF_JNE >> 4] = "!=", [BPF_JSGT >> 4] = "s>", + [BPF_JSLT >> 4] = "s<", [BPF_JSGE >> 4] = "s>=", + [BPF_JSLE >> 4] = "s<=", [BPF_CALL >> 4] = "call", [BPF_EXIT >> 4] = "exit", }; @@ -2383,27 +2387,37 @@ static void find_good_pkt_pointers(struct bpf_verifier_state *state, */ return; - /* LLVM can generate two kind of checks: + /* LLVM can generate four kind of checks: * - * Type 1: + * Type 1/2: * * r2 = r3; * r2 += 8; * if (r2 > pkt_end) goto * * + * r2 = r3; + * r2 += 8; + * if (r2 < pkt_end) goto + * + * * Where: * r2 == dst_reg, pkt_end == src_reg * r2=pkt(id=n,off=8,r=0) * r3=pkt(id=n,off=0,r=0) * - * Type 2: + * Type 3/4: * * r2 = r3; * r2 += 8; * if (pkt_end >= r2) goto * * + * r2 = r3; + * r2 += 8; + * if (pkt_end <= r2) goto + * + * * Where: * pkt_end == dst_reg, r2 == src_reg * r2=pkt(id=n,off=8,r=0) @@ -2471,6 +2485,14 @@ static void reg_set_min_max(struct bpf_reg_state *true_reg, false_reg->smax_value = min_t(s64, false_reg->smax_value, val); true_reg->smin_value = max_t(s64, true_reg->smin_value, val + 1); break; + case BPF_JLT: + false_reg->umin_value = max(false_reg->umin_value, val); + true_reg->umax_value = min(true_reg->umax_value, val - 1); + break; + case BPF_JSLT: + false_reg->smin_value = max_t(s64, false_reg->smin_value, val); + true_reg->smax_value = min_t(s64, true_reg->smax_value, val - 1); + break; case BPF_JGE: false_reg->umax_value = min(false_reg->umax_value, val - 1); true_reg->umin_value = max(true_reg->umin_value, val); @@ -2479,6 +2501,14 @@ static void reg_set_min_max(struct bpf_reg_state *true_reg, false_reg->smax_value = min_t(s64, false_reg->smax_value, val - 1); true_reg->smin_value = max_t(s64, true_reg->smin_value, val); break; + case BPF_JLE: + false_reg->umin_value = max(false_reg->umin_value, val + 1); + true_reg->umax_value = min(true_reg->umax_value, val); + break; + case BPF_JSLE: + false_reg->smin_value = max_t(s64, false_reg->smin_value, val + 1); + true_reg->smax_value = min_t(s64, true_reg->smax_value, val); + break; default: break; } @@ -2527,6 +2557,14 @@ static void reg_set_min_max_inv(struct bpf_reg_state *true_reg, true_reg->smax_value = min_t(s64, true_reg->smax_value, val - 1); false_reg->smin_value = max_t(s64, false_reg->smin_value, val); break; + case BPF_JLT: + true_reg->umin_value = max(true_reg->umin_value, val + 1); + false_reg->umax_value = min(false_reg->umax_value, val); + break; + case BPF_JSLT: + true_reg->smin_value = max_t(s64, true_reg->smin_value, val + 1); + false_reg->smax_value = min_t(s64, false_reg->smax_value, val); + break; case BPF_JGE: true_reg->umax_value = min(true_reg->umax_value, val); false_reg->umin_value = max(false_reg->umin_value, val + 1); @@ -2535,6 +2573,14 @@ static void reg_set_min_max_inv(struct bpf_reg_state *true_reg, true_reg->smax_value = min_t(s64, true_reg->smax_value, val); false_reg->smin_value = max_t(s64, false_reg->smin_value, val + 1); break; + case BPF_JLE: + true_reg->umin_value = max(true_reg->umin_value, val); + false_reg->umax_value = min(false_reg->umax_value, val - 1); + break; + case BPF_JSLE: + true_reg->smin_value = max_t(s64, true_reg->smin_value, val); + false_reg->smax_value = min_t(s64, false_reg->smax_value, val - 1); + break; default: break; } @@ -2659,7 +2705,7 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env, u8 opcode = BPF_OP(insn->code); int err; - if (opcode > BPF_EXIT) { + if (opcode > BPF_JSLE) { verbose("invalid BPF_JMP opcode %x\n", opcode); return -EINVAL; } @@ -2761,10 +2807,18 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env, dst_reg->type == PTR_TO_PACKET && regs[insn->src_reg].type == PTR_TO_PACKET_END) { find_good_pkt_pointers(this_branch, dst_reg); + } else if (BPF_SRC(insn->code) == BPF_X && opcode == BPF_JLT && + dst_reg->type == PTR_TO_PACKET && + regs[insn->src_reg].type == PTR_TO_PACKET_END) { + find_good_pkt_pointers(other_branch, dst_reg); } else if (BPF_SRC(insn->code) == BPF_X && opcode == BPF_JGE && dst_reg->type == PTR_TO_PACKET_END && regs[insn->src_reg].type == PTR_TO_PACKET) { find_good_pkt_pointers(other_branch, ®s[insn->src_reg]); + } else if (BPF_SRC(insn->code) == BPF_X && opcode == BPF_JLE && + dst_reg->type == PTR_TO_PACKET_END && + regs[insn->src_reg].type == PTR_TO_PACKET) { + find_good_pkt_pointers(this_branch, ®s[insn->src_reg]); } else if (is_pointer_value(env, insn->dst_reg)) { verbose("R%d pointer comparison prohibited\n", insn->dst_reg); return -EACCES;