diff mbox series

[bpf-next] bpf: Fix register equivalence tracking.

Message ID 20201014175608.1416-1-alexei.starovoitov@gmail.com
State New

Commit Message

Alexei Starovoitov Oct. 14, 2020, 5:56 p.m. UTC
From: Alexei Starovoitov <ast@kernel.org>

The 64-bit JEQ/JNE handling in reg_set_min_max() was clearing reg->id in either
the true or the false branch. When the 'if (reg->id)' check was done on the other
branch, the counterpart register would have reg->id == 0 when passed to
find_equal_scalars(). In that case the helper would incorrectly identify other
registers with id == 0 as equivalent and propagate the state incorrectly.
Fix it by preserving the ID across reg_set_min_max().
In other words, any kind of comparison operator on a scalar register
should preserve its ID, so that the verifier recognizes:
r1 = r2
if (r1 == 20) {
  #1 here both r1 and r2 == 20
} else if (r2 < 20) {
  #2 here both r1 and r2 < 20
}

This patch addresses case #1. Case #2 was already working correctly.

Fixes: 75748837b7e5 ("bpf: Propagate scalar ranges through register assignments.")
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
---
 kernel/bpf/verifier.c                         | 38 ++++++++++++-------
 .../testing/selftests/bpf/verifier/regalloc.c | 26 +++++++++++++
 2 files changed, 51 insertions(+), 13 deletions(-)
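
The ID-based equivalence tracking described in the commit message can be illustrated with a toy model. This is only a sketch under stated assumptions: struct toy_reg and every toy_* function are hypothetical stand-ins, not the verifier's real data structures, and the model ignores branches, register types and ranges. It shows why clearing the ID before the propagation step makes unrelated registers with id == 0 look equivalent.

```c
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>

#define NREGS 4

/* Toy stand-in for a verifier register state (hypothetical, not the kernel's). */
struct toy_reg {
	uint32_t id;     /* 0 means "not linked to any other register" */
	bool known;      /* true once the value is a known constant */
	uint64_t value;  /* meaningful only when known is true */
};

/* r_dst = r_src: both registers end up sharing one non-zero id. */
static void toy_mov(struct toy_reg *regs, int dst, int src, uint32_t *next_id)
{
	assert(dst >= 0 && dst < NREGS && src >= 0 && src < NREGS);
	if (!regs[src].known && regs[src].id == 0)
		regs[src].id = ++(*next_id);
	regs[dst] = regs[src];
}

/* Mark a register as a known constant; clear_id models the pre-fix behaviour. */
static void toy_mark_known(struct toy_reg *reg, uint64_t imm, bool clear_id)
{
	if (clear_id)
		reg->id = 0;
	reg->known = true;
	reg->value = imm;
}

/* Copy the state of known_reg into every register carrying the same id. */
static void toy_find_equal(struct toy_reg *regs, const struct toy_reg *known_reg)
{
	struct toy_reg snapshot = *known_reg;

	for (int i = 0; i < NREGS; i++)
		if (regs[i].id == snapshot.id)
			regs[i] = snapshot;
}

/* Count how many of r0..r3 are "known" in the taken branch of `if (r1 == 20)`. */
static int toy_known_after_jeq(bool clear_id)
{
	struct toy_reg regs[NREGS] = {0};   /* r0..r3 unknown, all ids 0 */
	uint32_t next_id = 0;
	int n = 0;

	toy_mov(regs, 1, 2, &next_id);           /* r1 = r2 */
	toy_mark_known(&regs[1], 20, clear_id);  /* r1 == 20 holds here */
	toy_find_equal(regs, &regs[1]);

	for (int i = 0; i < NREGS; i++)
		n += regs[i].known;
	return n;
}
```

In this model, preserving the ID marks exactly r1 and r2 as known. Clearing it first marks r0, r1 and r3 instead, and leaves r2, the one register that really is equal to r1, untouched.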

Comments

Andrii Nakryiko Oct. 14, 2020, 11:09 p.m. UTC | #1
On Wed, Oct 14, 2020 at 10:59 AM Alexei Starovoitov
<alexei.starovoitov@gmail.com> wrote:
>
> From: Alexei Starovoitov <ast@kernel.org>
>
> The 64-bit JEQ/JNE handling in reg_set_min_max() was clearing reg->id in either
> true or false branch. In the case 'if (reg->id)' check was done on the other
> branch the counter part register would have reg->id == 0 when called into
> find_equal_scalars(). In such case the helper would incorrectly identify other
> registers with id == 0 as equivalent and propagate the state incorrectly.
> Fix it by preserving ID across reg_set_min_max().
> In other words any kind of comparison operator on the scalar register
> should preserve its ID to recognize:
> r1 = r2
> if (r1 == 20) {
>   #1 here both r1 and r2 == 20
> } else if (r2 < 20) {
>   #2 here both r1 and r2 < 20
> }
>
> The patch is addressing #1 case. The #2 was working correctly already.
>
> Fixes: 75748837b7e5 ("bpf: Propagate scalar ranges through register assignments.")
> Signed-off-by: Alexei Starovoitov <ast@kernel.org>
> ---

The number of underscores is a rather subtle difference, but this fixes the bug, so:

Acked-by: Andrii Nakryiko <andrii@kernel.org>


>  kernel/bpf/verifier.c                         | 38 ++++++++++++-------
>  .../testing/selftests/bpf/verifier/regalloc.c | 26 +++++++++++++
>  2 files changed, 51 insertions(+), 13 deletions(-)
>

[...]
John Fastabend Oct. 15, 2020, 4:04 a.m. UTC | #2
Andrii Nakryiko wrote:
> On Wed, Oct 14, 2020 at 10:59 AM Alexei Starovoitov
> <alexei.starovoitov@gmail.com> wrote:
> >
> > From: Alexei Starovoitov <ast@kernel.org>
> >
> > The 64-bit JEQ/JNE handling in reg_set_min_max() was clearing reg->id in either
> > true or false branch. In the case 'if (reg->id)' check was done on the other
> > branch the counter part register would have reg->id == 0 when called into
> > find_equal_scalars(). In such case the helper would incorrectly identify other
> > registers with id == 0 as equivalent and propagate the state incorrectly.

One thought: it seems we should never have reg->id == 0 in find_equal_scalars().
Would it be worthwhile to add an additional check here? Something like,

  if (known_reg->id == 0)
	return;

Or even a WARN_ON_ONCE() there? Not sold either way, but maybe worth thinking
about.

> > Fix it by preserving ID across reg_set_min_max().
> > In other words any kind of comparison operator on the scalar register
> > should preserve its ID to recognize:
> > r1 = r2
> > if (r1 == 20) {
> >   #1 here both r1 and r2 == 20
> > } else if (r2 < 20) {
> >   #2 here both r1 and r2 < 20
> > }
> >
> > The patch is addressing #1 case. The #2 was working correctly already.
> >
> > Fixes: 75748837b7e5 ("bpf: Propagate scalar ranges through register assignments.")
> > Signed-off-by: Alexei Starovoitov <ast@kernel.org>
> > ---
> 
> Number of underscores is a bit subtle a difference, but this fixes the bug, so:
> 
> Acked-by: Andrii Nakryiko <andrii@kernel.org>
> 

Nice catch,

Acked-by: John Fastabend <john.fastabend@gmail.com>

> 
> >  kernel/bpf/verifier.c                         | 38 ++++++++++++-------
> >  .../testing/selftests/bpf/verifier/regalloc.c | 26 +++++++++++++
> >  2 files changed, 51 insertions(+), 13 deletions(-)
> >
> 
> [...]
Alexei Starovoitov Oct. 15, 2020, 4:19 a.m. UTC | #3
On Wed, Oct 14, 2020 at 09:04:23PM -0700, John Fastabend wrote:
> Andrii Nakryiko wrote:
> > On Wed, Oct 14, 2020 at 10:59 AM Alexei Starovoitov
> > <alexei.starovoitov@gmail.com> wrote:
> > >
> > > From: Alexei Starovoitov <ast@kernel.org>
> > >
> > > The 64-bit JEQ/JNE handling in reg_set_min_max() was clearing reg->id in either
> > > true or false branch. In the case 'if (reg->id)' check was done on the other
> > > branch the counter part register would have reg->id == 0 when called into
> > > find_equal_scalars(). In such case the helper would incorrectly identify other
> > > registers with id == 0 as equivalent and propagate the state incorrectly.
>
> One thought. It seems we should never have reg->id=0 in find_equal_scalars()
> would it be worthwhile to add an additional check here? Something like,
>
>   if (known_reg->id == 0)
> 	return
>
> Or even a WARN_ON_ONCE() there? Not sold either way, but maybe worth thinking
> about.

That cannot happen anymore due to the
if (dst_reg->id && !WARN_ON_ONCE(dst_reg->id != other_branch_regs[insn->dst_reg].id))
check in the caller.
I prefer not to repeat the same check twice. Also, I really don't like defensive programming.
if (known_reg->id == 0)
       return;
is exactly that.
If we had that already, as Andrii argued in the original thread, we would have
never noticed this issue. The <, >, <= ops would have worked, but == would have been
sort-of working. It would mark one branch instead of both, and sometimes
neither of the branches. I'd rather have bugs like this one hurt and get caught
quickly than have a warm feeling of being safe while sailing into the unknown.
John Fastabend Oct. 15, 2020, 4:27 a.m. UTC | #4
Alexei Starovoitov wrote:
> On Wed, Oct 14, 2020 at 09:04:23PM -0700, John Fastabend wrote:
> > Andrii Nakryiko wrote:
> > > On Wed, Oct 14, 2020 at 10:59 AM Alexei Starovoitov
> > > <alexei.starovoitov@gmail.com> wrote:
> > > >
> > > > From: Alexei Starovoitov <ast@kernel.org>
> > > >
> > > > The 64-bit JEQ/JNE handling in reg_set_min_max() was clearing reg->id in either
> > > > true or false branch. In the case 'if (reg->id)' check was done on the other
> > > > branch the counter part register would have reg->id == 0 when called into
> > > > find_equal_scalars(). In such case the helper would incorrectly identify other
> > > > registers with id == 0 as equivalent and propagate the state incorrectly.
> > 
> > One thought. It seems we should never have reg->id=0 in find_equal_scalars()
> > would it be worthwhile to add an additional check here? Something like,
> > 
> >   if (known_reg->id == 0)
> > 	return
> >
> > Or even a WARN_ON_ONCE() there? Not sold either way, but maybe worth thinking
> > about.
> 
> That cannot happen anymore due to
> if (dst_reg->id && !WARN_ON_ONCE(dst_reg->id != other_branch_regs[insn->dst_reg].id))
> check in the caller.
> I prefer not to repeat the same check twice. Also I really don't like defensive programming.
> if (known_reg->id == 0)
>        return;
> is exactly that.
> If we had that already, as Andrii argued in the original thread, we would have
> never noticed this issue. <, >, <= ops would have worked, but == would be
> sort-of working. It would mark one branch instead of both, and sometimes
> neither of the branches. I'd rather have bugs like this one hurting and caught
> quickly instead of warm feeling of being safe and sailing into unknown.

Agreed, although a WARN_ON_ONCE would also have caught it.
Alexei Starovoitov Oct. 15, 2020, 4:33 a.m. UTC | #5
On Wed, Oct 14, 2020 at 09:27:17PM -0700, John Fastabend wrote:
> Alexei Starovoitov wrote:
> > On Wed, Oct 14, 2020 at 09:04:23PM -0700, John Fastabend wrote:
> > > Andrii Nakryiko wrote:
> > > > On Wed, Oct 14, 2020 at 10:59 AM Alexei Starovoitov
> > > > <alexei.starovoitov@gmail.com> wrote:
> > > > >
> > > > > From: Alexei Starovoitov <ast@kernel.org>
> > > > >
> > > > > The 64-bit JEQ/JNE handling in reg_set_min_max() was clearing reg->id in either
> > > > > true or false branch. In the case 'if (reg->id)' check was done on the other
> > > > > branch the counter part register would have reg->id == 0 when called into
> > > > > find_equal_scalars(). In such case the helper would incorrectly identify other
> > > > > registers with id == 0 as equivalent and propagate the state incorrectly.
> > >
> > > One thought. It seems we should never have reg->id=0 in find_equal_scalars()
> > > would it be worthwhile to add an additional check here? Something like,
> > >
> > >   if (known_reg->id == 0)
> > > 	return
> > >
> > > Or even a WARN_ON_ONCE() there? Not sold either way, but maybe worth thinking
> > > about.
> >
> > That cannot happen anymore due to
> > if (dst_reg->id && !WARN_ON_ONCE(dst_reg->id != other_branch_regs[insn->dst_reg].id))
> > check in the caller.
> > I prefer not to repeat the same check twice. Also I really don't like defensive programming.
> > if (known_reg->id == 0)
> >        return;
> > is exactly that.
> > If we had that already, as Andrii argued in the original thread, we would have
> > never noticed this issue. <, >, <= ops would have worked, but == would be
> > sort-of working. It would mark one branch instead of both, and sometimes
> > neither of the branches. I'd rather have bugs like this one hurting and caught
> > quickly instead of warm feeling of being safe and sailing into unknown.
>
> Agree. Although a WARN_ON_ONCE would have also been caught.

Right. Such a WARN_ON_ONCE would definitely have been nice, either in the caller
or in the callee, if only I had thought that the id could somehow be zero here.
In retrospect it makes sense that there is a possibility that the IDs of regs in
this_branch and other_branch may diverge.
Hence I'm adding the warn to check for this specific divergence.
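
The 'if (id && !WARN_ON_ONCE(...))' idiom discussed above works because the warning helper returns the condition it was given, so one expression both reports the divergence and skips the propagation. The following is a minimal userspace sketch of that shape, not kernel code: toy_warn_on_once() and toy_should_propagate() are hypothetical names, and the "warned" flag is passed explicitly here, whereas the kernel macro keeps its once-only state per call site.

```c
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

/* Userspace stand-in for WARN_ON_ONCE(): prints the first time cond is
 * true and always returns cond, so it can be used inside an if ().
 */
static bool toy_warn_on_once(bool cond, bool *warned, const char *what)
{
	assert(warned != NULL);
	if (cond && !*warned) {
		*warned = true;
		fprintf(stderr, "WARNING: %s\n", what);
	}
	return cond;
}

/* Hypothetical caller-side check: propagate only when the register has
 * an id and both branches agree on it; warn once if they diverge.
 */
static bool toy_should_propagate(uint32_t this_id, uint32_t other_id,
				 bool *warned)
{
	return this_id &&
	       !toy_warn_on_once(this_id != other_id, warned,
				 "register ids diverged between branches");
}
```

With id == 0 the left-hand side short-circuits and nothing is reported. With matching non-zero ids the propagation goes ahead. With diverging ids the propagation is skipped and the warning fires once, which makes the bug visible rather than silently tolerated.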
John Fastabend Oct. 15, 2020, 5:23 a.m. UTC | #6
Alexei Starovoitov wrote:
> On Wed, Oct 14, 2020 at 09:27:17PM -0700, John Fastabend wrote:
> > Alexei Starovoitov wrote:
> > > On Wed, Oct 14, 2020 at 09:04:23PM -0700, John Fastabend wrote:
> > > > Andrii Nakryiko wrote:
> > > > > On Wed, Oct 14, 2020 at 10:59 AM Alexei Starovoitov
> > > > > <alexei.starovoitov@gmail.com> wrote:
> > > > > >
> > > > > > From: Alexei Starovoitov <ast@kernel.org>
> > > > > >
> > > > > > The 64-bit JEQ/JNE handling in reg_set_min_max() was clearing reg->id in either
> > > > > > true or false branch. In the case 'if (reg->id)' check was done on the other
> > > > > > branch the counter part register would have reg->id == 0 when called into
> > > > > > find_equal_scalars(). In such case the helper would incorrectly identify other
> > > > > > registers with id == 0 as equivalent and propagate the state incorrectly.
> > > > 
> > > > One thought. It seems we should never have reg->id=0 in find_equal_scalars()
> > > > would it be worthwhile to add an additional check here? Something like,
> > > > 
> > > >   if (known_reg->id == 0)
> > > > 	return
> > > >
> > > > Or even a WARN_ON_ONCE() there? Not sold either way, but maybe worth thinking
> > > > about.
> > > 
> > > That cannot happen anymore due to
> > > if (dst_reg->id && !WARN_ON_ONCE(dst_reg->id != other_branch_regs[insn->dst_reg].id))
> > > check in the caller.
> > > I prefer not to repeat the same check twice. Also I really don't like defensive programming.
> > > if (known_reg->id == 0)
> > >        return;
> > > is exactly that.
> > > If we had that already, as Andrii argued in the original thread, we would have
> > > never noticed this issue. <, >, <= ops would have worked, but == would be
> > > sort-of working. It would mark one branch instead of both, and sometimes
> > > neither of the branches. I'd rather have bugs like this one hurting and caught
> > > quickly instead of warm feeling of being safe and sailing into unknown.
> > 
> > Agree. Although a WARN_ON_ONCE would have also been caught.
> 
> Right. Such WARN_ON_ONCE would definitely have been nice either in the caller
> or in the callee. If I could have thought that id could be zero somehow here.
> In retrospect it makes sense that there is possibility that IDs of regs in
> this_branch and other_branch may diverge.
> Hence I'm adding the warn to check for this specific divergence.

LGTM thanks.
Yonghong Song Oct. 15, 2020, 5:46 a.m. UTC | #7
On 10/14/20 10:56 AM, Alexei Starovoitov wrote:
> From: Alexei Starovoitov <ast@kernel.org>
>
> The 64-bit JEQ/JNE handling in reg_set_min_max() was clearing reg->id in either
> true or false branch. In the case 'if (reg->id)' check was done on the other
> branch the counter part register would have reg->id == 0 when called into
> find_equal_scalars(). In such case the helper would incorrectly identify other
> registers with id == 0 as equivalent and propagate the state incorrectly.
> Fix it by preserving ID across reg_set_min_max().
> In other words any kind of comparison operator on the scalar register
> should preserve its ID to recognize:
> r1 = r2
> if (r1 == 20) {
>    #1 here both r1 and r2 == 20
> } else if (r2 < 20) {
>    #2 here both r1 and r2 < 20
> }
>
> The patch is addressing #1 case. The #2 was working correctly already.
>
> Fixes: 75748837b7e5 ("bpf: Propagate scalar ranges through register assignments.")
> Signed-off-by: Alexei Starovoitov <ast@kernel.org>

This fixed an issue that appeared in our production system, where packets could
be incorrectly dropped.

Test-by: Yonghong Song <yhs@fb.com>
Yonghong Song Oct. 15, 2020, 5:48 a.m. UTC | #8
On 10/14/20 10:46 PM, Yonghong Song wrote:
>
> On 10/14/20 10:56 AM, Alexei Starovoitov wrote:
>> From: Alexei Starovoitov <ast@kernel.org>
>>
>> The 64-bit JEQ/JNE handling in reg_set_min_max() was clearing reg->id in either
>> true or false branch. In the case 'if (reg->id)' check was done on the other
>> branch the counter part register would have reg->id == 0 when called into
>> find_equal_scalars(). In such case the helper would incorrectly identify other
>> registers with id == 0 as equivalent and propagate the state incorrectly.
>> Fix it by preserving ID across reg_set_min_max().
>> In other words any kind of comparison operator on the scalar register
>> should preserve its ID to recognize:
>> r1 = r2
>> if (r1 == 20) {
>>    #1 here both r1 and r2 == 20
>> } else if (r2 < 20) {
>>    #2 here both r1 and r2 < 20
>> }
>>
>> The patch is addressing #1 case. The #2 was working correctly already.
>>
>> Fixes: 75748837b7e5 ("bpf: Propagate scalar ranges through register assignments.")
>> Signed-off-by: Alexei Starovoitov <ast@kernel.org>
>
> This fixed an issue appeared in our production system where packets may
> be incorrectly dropped.
>
> Test-by: Yonghong Song <yhs@fb.com>

Sorry for typo, it should be:
   Tested-by: Yonghong Song <yhs@fb.com>
patchwork-bot+netdevbpf@kernel.org Oct. 15, 2020, 2:10 p.m. UTC | #9
Hello:

This patch was applied to bpf/bpf-next.git (refs/heads/master):

On Wed, 14 Oct 2020 10:56:08 -0700 you wrote:
> From: Alexei Starovoitov <ast@kernel.org>
>
> The 64-bit JEQ/JNE handling in reg_set_min_max() was clearing reg->id in either
> true or false branch. In the case 'if (reg->id)' check was done on the other
> branch the counter part register would have reg->id == 0 when called into
> find_equal_scalars(). In such case the helper would incorrectly identify other
> registers with id == 0 as equivalent and propagate the state incorrectly.
> Fix it by preserving ID across reg_set_min_max().
> In other words any kind of comparison operator on the scalar register
> should preserve its ID to recognize:
> r1 = r2
> if (r1 == 20) {
>   #1 here both r1 and r2 == 20
> } else if (r2 < 20) {
>   #2 here both r1 and r2 < 20
> }
>
> [...]

Here is the summary with links:
  - [bpf-next] bpf: Fix register equivalence tracking.
    https://git.kernel.org/bpf/bpf-next/c/e688c3db7ca6

You are awesome, thank you!
--
Deet-doot-dot, I am a bot.
https://korg.docs.kernel.org/patchwork/pwbot.html

Patch

diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
index c43a5e8f0818..39d7f44e7c92 100644
--- a/kernel/bpf/verifier.c
+++ b/kernel/bpf/verifier.c
@@ -1010,14 +1010,9 @@  static const int caller_saved[CALLER_SAVED_REGS] = {
 static void __mark_reg_not_init(const struct bpf_verifier_env *env,
 				struct bpf_reg_state *reg);
 
-/* Mark the unknown part of a register (variable offset or scalar value) as
- * known to have the value @imm.
- */
-static void __mark_reg_known(struct bpf_reg_state *reg, u64 imm)
+/* This helper doesn't clear reg->id */
+static void ___mark_reg_known(struct bpf_reg_state *reg, u64 imm)
 {
-	/* Clear id, off, and union(map_ptr, range) */
-	memset(((u8 *)reg) + sizeof(reg->type), 0,
-	       offsetof(struct bpf_reg_state, var_off) - sizeof(reg->type));
 	reg->var_off = tnum_const(imm);
 	reg->smin_value = (s64)imm;
 	reg->smax_value = (s64)imm;
@@ -1030,6 +1025,17 @@  static void __mark_reg_known(struct bpf_reg_state *reg, u64 imm)
 	reg->u32_max_value = (u32)imm;
 }
 
+/* Mark the unknown part of a register (variable offset or scalar value) as
+ * known to have the value @imm.
+ */
+static void __mark_reg_known(struct bpf_reg_state *reg, u64 imm)
+{
+	/* Clear id, off, and union(map_ptr, range) */
+	memset(((u8 *)reg) + sizeof(reg->type), 0,
+	       offsetof(struct bpf_reg_state, var_off) - sizeof(reg->type));
+	___mark_reg_known(reg, imm);
+}
+
 static void __mark_reg32_known(struct bpf_reg_state *reg, u64 imm)
 {
 	reg->var_off = tnum_const_subreg(reg->var_off, imm);
@@ -7001,14 +7007,18 @@  static void reg_set_min_max(struct bpf_reg_state *true_reg,
 		struct bpf_reg_state *reg =
 			opcode == BPF_JEQ ? true_reg : false_reg;
 
-		/* For BPF_JEQ, if this is false we know nothing Jon Snow, but
-		 * if it is true we know the value for sure. Likewise for
-		 * BPF_JNE.
+		/* JEQ/JNE comparison doesn't change the register equivalence.
+		 * r1 = r2;
+		 * if (r1 == 42) goto label;
+		 * ...
+		 * label: // here both r1 and r2 are known to be 42.
+		 *
+		 * Hence when marking register as known preserve its ID.
 		 */
 		if (is_jmp32)
 			__mark_reg32_known(reg, val32);
 		else
-			__mark_reg_known(reg, val);
+			___mark_reg_known(reg, val);
 		break;
 	}
 	case BPF_JSET:
@@ -7551,7 +7561,8 @@  static int check_cond_jmp_op(struct bpf_verifier_env *env,
 				reg_combine_min_max(&other_branch_regs[insn->src_reg],
 						    &other_branch_regs[insn->dst_reg],
 						    src_reg, dst_reg, opcode);
-			if (src_reg->id) {
+			if (src_reg->id &&
+			    !WARN_ON_ONCE(src_reg->id != other_branch_regs[insn->src_reg].id)) {
 				find_equal_scalars(this_branch, src_reg);
 				find_equal_scalars(other_branch, &other_branch_regs[insn->src_reg]);
 			}
@@ -7563,7 +7574,8 @@  static int check_cond_jmp_op(struct bpf_verifier_env *env,
 					opcode, is_jmp32);
 	}
 
-	if (dst_reg->type == SCALAR_VALUE && dst_reg->id) {
+	if (dst_reg->type == SCALAR_VALUE && dst_reg->id &&
+	    !WARN_ON_ONCE(dst_reg->id != other_branch_regs[insn->dst_reg].id)) {
 		find_equal_scalars(this_branch, dst_reg);
 		find_equal_scalars(other_branch, &other_branch_regs[insn->dst_reg]);
 	}
diff --git a/tools/testing/selftests/bpf/verifier/regalloc.c b/tools/testing/selftests/bpf/verifier/regalloc.c
index ac71b824f97a..4ad7e05de706 100644
--- a/tools/testing/selftests/bpf/verifier/regalloc.c
+++ b/tools/testing/selftests/bpf/verifier/regalloc.c
@@ -241,3 +241,29 @@ 
 	.result = ACCEPT,
 	.prog_type = BPF_PROG_TYPE_TRACEPOINT,
 },
+{
+	"regalloc, spill, JEQ",
+	.insns = {
+	BPF_MOV64_REG(BPF_REG_6, BPF_REG_1),
+	BPF_ST_MEM(BPF_DW, BPF_REG_10, -8, 0),
+	BPF_MOV64_REG(BPF_REG_2, BPF_REG_10),
+	BPF_ALU64_IMM(BPF_ADD, BPF_REG_2, -8),
+	BPF_LD_MAP_FD(BPF_REG_1, 0),
+	BPF_EMIT_CALL(BPF_FUNC_map_lookup_elem),
+	BPF_STX_MEM(BPF_DW, BPF_REG_10, BPF_REG_0, -8), /* spill r0 */
+	BPF_JMP_IMM(BPF_JEQ, BPF_REG_0, 0, 0),
+	/* The verifier will walk the rest twice with r0 == 0 and r0 == map_value */
+	BPF_EMIT_CALL(BPF_FUNC_get_prandom_u32),
+	BPF_MOV64_REG(BPF_REG_2, BPF_REG_0),
+	BPF_JMP_IMM(BPF_JEQ, BPF_REG_2, 20, 0),
+	/* The verifier will walk the rest two more times with r0 == 20 and r0 == unknown */
+	BPF_LDX_MEM(BPF_DW, BPF_REG_3, BPF_REG_10, -8), /* fill r3 with map_value */
+	BPF_JMP_IMM(BPF_JEQ, BPF_REG_3, 0, 1), /* skip ldx if map_value == NULL */
+	/* Buggy verifier will think that r3 == 20 here */
+	BPF_LDX_MEM(BPF_DW, BPF_REG_0, BPF_REG_3, 0), /* read from map_value */
+	BPF_EXIT_INSN(),
+	},
+	.fixup_map_hash_48b = { 4 },
+	.result = ACCEPT,
+	.prog_type = BPF_PROG_TYPE_TRACEPOINT,
+},
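
The helper split in the verifier.c hunk above (one function that only sets the value fields, and a wrapper that first zeroes the bookkeeping fields with memset() and offsetof()) can be sketched in isolation. The layout below is a hypothetical toy, not the real struct bpf_reg_state, and the toy_* names are assumptions made for illustration. The only property carried over from the patch is that "type" comes first and "var_off" marks where the value-tracking fields begin.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

/* Toy layout (hypothetical): `type` first, bookkeeping fields next,
 * then the value-tracking fields starting at `var_off`.
 */
struct toy_reg_state {
	uint32_t type;
	int32_t off;
	uint32_t id;
	uint64_t var_off;
	int64_t smin_value;
	int64_t smax_value;
};

/* Sets only the value-tracking fields; type, off and id are left alone. */
static void toy_set_known_keep_id(struct toy_reg_state *reg, uint64_t imm)
{
	reg->var_off = imm;
	reg->smin_value = (int64_t)imm;
	reg->smax_value = (int64_t)imm;
}

/* Zeroes every byte between `type` and `var_off` (off, id and any
 * padding), then sets the value-tracking fields.
 */
static void toy_set_known(struct toy_reg_state *reg, uint64_t imm)
{
	/* The byte arithmetic below assumes `type` sits at offset 0. */
	assert(offsetof(struct toy_reg_state, type) == 0);
	memset((uint8_t *)reg + sizeof(reg->type), 0,
	       offsetof(struct toy_reg_state, var_off) - sizeof(reg->type));
	toy_set_known_keep_id(reg, imm);
}
```

Callers that must keep the equivalence link, such as the JEQ/JNE case in the patch, would use the first function. Callers that want a fresh constant with no link to other registers would use the wrapper.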