diff mbox series

[bpf-next,1/3] bpf: Propagate scalar ranges through register assignments.

Message ID 20201006200955.12350-2-alexei.starovoitov@gmail.com
State New
Headers show
Series bpf: Make the verifier recognize llvm register allocation patterns. | expand

Commit Message

Alexei Starovoitov Oct. 6, 2020, 8:09 p.m. UTC
From: Alexei Starovoitov <ast@kernel.org>

The llvm register allocator may use two different registers representing the
same virtual register. In such case the following pattern can be observed:
1047: (bf) r9 = r6
1048: (a5) if r6 < 0x1000 goto pc+1
1050: ...
1051: (a5) if r9 < 0x2 goto pc+66
1052: ...
1053: (bf) r2 = r9 /* r2 needs to have upper and lower bounds */

In order to track this information without backtracking allocate ID
for scalars in a similar way as it's done for find_good_pkt_pointers().

When the verifier encounters r9 = r6 assignment it will assign the same ID
to both registers. Later if either register range is narrowed via conditional
jump propagate the register state into the other register.

Clear register ID in adjust_reg_min_max_vals() for any alu instruction.

Newly allocated register ID is ignored for scalars in regsafe() and doesn't
affect state pruning. mark_reg_unknown() also clears the ID.

Signed-off-by: Alexei Starovoitov <ast@kernel.org>
---
 kernel/bpf/verifier.c                         | 38 +++++++++++++++++++
 .../testing/selftests/bpf/prog_tests/align.c  | 16 ++++----
 .../bpf/verifier/direct_packet_access.c       |  2 +-
 3 files changed, 47 insertions(+), 9 deletions(-)

Comments

Andrii Nakryiko Oct. 7, 2020, 1:56 a.m. UTC | #1
On Tue, Oct 6, 2020 at 1:14 PM Alexei Starovoitov
<alexei.starovoitov@gmail.com> wrote:
>

> From: Alexei Starovoitov <ast@kernel.org>

>

> The llvm register allocator may use two different registers representing the

> same virtual register. In such case the following pattern can be observed:

> 1047: (bf) r9 = r6

> 1048: (a5) if r6 < 0x1000 goto pc+1

> 1050: ...

> 1051: (a5) if r9 < 0x2 goto pc+66

> 1052: ...

> 1053: (bf) r2 = r9 /* r2 needs to have upper and lower bounds */

>

> In order to track this information without backtracking allocate ID

> for scalars in a similar way as it's done for find_good_pkt_pointers().

>

> When the verifier encounters r9 = r6 assignment it will assign the same ID

> to both registers. Later if either register range is narrowed via conditional

> jump propagate the register state into the other register.

>

> Clear register ID in adjust_reg_min_max_vals() for any alu instruction.

>

> Newly allocated register ID is ignored for scalars in regsafe() and doesn't

> affect state pruning. mark_reg_unknown() also clears the ID.

>

> Signed-off-by: Alexei Starovoitov <ast@kernel.org>

> ---


I couldn't find the problem with the logic, though it's quite
non-obvious at times that reg->id will be cleared on BPF_END/BPF_NEG
and few other operations. But I think naming of this function can be
improved, see below.

Also, profiler.c is great, but it would still be nice to add selftest
to test_verifier that will explicitly test the logic in this patch

>  kernel/bpf/verifier.c                         | 38 +++++++++++++++++++

>  .../testing/selftests/bpf/prog_tests/align.c  | 16 ++++----

>  .../bpf/verifier/direct_packet_access.c       |  2 +-

>  3 files changed, 47 insertions(+), 9 deletions(-)

>

> diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c

> index 01120acab09a..09e17b483b0b 100644

> --- a/kernel/bpf/verifier.c

> +++ b/kernel/bpf/verifier.c

> @@ -6432,6 +6432,8 @@ static int adjust_reg_min_max_vals(struct bpf_verifier_env *env,

>         src_reg = NULL;

>         if (dst_reg->type != SCALAR_VALUE)

>                 ptr_reg = dst_reg;

> +       else

> +               dst_reg->id = 0;

>         if (BPF_SRC(insn->code) == BPF_X) {

>                 src_reg = &regs[insn->src_reg];

>                 if (src_reg->type != SCALAR_VALUE) {

> @@ -6565,6 +6567,8 @@ static int check_alu_op(struct bpf_verifier_env *env, struct bpf_insn *insn)

>                                 /* case: R1 = R2

>                                  * copy register state to dest reg

>                                  */

> +                               if (src_reg->type == SCALAR_VALUE)

> +                                       src_reg->id = ++env->id_gen;

>                                 *dst_reg = *src_reg;

>                                 dst_reg->live |= REG_LIVE_WRITTEN;

>                                 dst_reg->subreg_def = DEF_NOT_SUBREG;

> @@ -7365,6 +7369,30 @@ static bool try_match_pkt_pointers(const struct bpf_insn *insn,

>         return true;

>  }

>

> +static void find_equal_scalars(struct bpf_verifier_state *vstate,

> +                              struct bpf_reg_state *known_reg)


this is double-misleading name:

1) it's not just "find", but also "update" (or rather the purpose of
this function is specifically to update registers, not find them, as
we don't really return found register)
2) "equal" is not exactly true either. You can have two scalar
register with exactly the same state, but they might not share ->id.
So it's less about being equal, rather being "linked" by assignment.

> +{

> +       struct bpf_func_state *state;

> +       struct bpf_reg_state *reg;

> +       int i, j;

> +

> +       for (i = 0; i <= vstate->curframe; i++) {

> +               state = vstate->frame[i];

> +               for (j = 0; j < MAX_BPF_REG; j++) {

> +                       reg = &state->regs[j];

> +                       if (reg->type == SCALAR_VALUE && reg->id == known_reg->id)

> +                               *reg = *known_reg;

> +               }

> +

> +               bpf_for_each_spilled_reg(j, state, reg) {

> +                       if (!reg)

> +                               continue;

> +                       if (reg->type == SCALAR_VALUE && reg->id == known_reg->id)

> +                               *reg = *known_reg;

> +               }

> +       }

> +}

> +

>  static int check_cond_jmp_op(struct bpf_verifier_env *env,

>                              struct bpf_insn *insn, int *insn_idx)

>  {

> @@ -7493,6 +7521,11 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env,

>                                 reg_combine_min_max(&other_branch_regs[insn->src_reg],

>                                                     &other_branch_regs[insn->dst_reg],

>                                                     src_reg, dst_reg, opcode);

> +                       if (src_reg->id) {

> +                               find_equal_scalars(this_branch, src_reg);

> +                               find_equal_scalars(other_branch, &other_branch_regs[insn->src_reg]);

> +                       }

> +

>                 }

>         } else if (dst_reg->type == SCALAR_VALUE) {

>                 reg_set_min_max(&other_branch_regs[insn->dst_reg],

> @@ -7500,6 +7533,11 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env,

>                                         opcode, is_jmp32);

>         }

>

> +       if (dst_reg->type == SCALAR_VALUE && dst_reg->id) {

> +               find_equal_scalars(this_branch, dst_reg);

> +               find_equal_scalars(other_branch, &other_branch_regs[insn->dst_reg]);


will this cover the case above where reg_combine_min_max() can update
dst_reg's as well? Even if yes, it probably would be more
straightforward to call appropriate updates in the respective if
branches (it's just a single line for each register, so not like it's
duplicating tons of code). It will make reasoning about this logic
easier, IMO. Also, moving reg->id check into find_equal_scalars()
would make the above suggestion even cleaner.

> +       }

> +

>         /* detect if R == 0 where R is returned from bpf_map_lookup_elem().

>          * NOTE: these optimizations below are related with pointer comparison

>          *       which will never be JMP32.


[...]
Alexei Starovoitov Oct. 7, 2020, 2:18 a.m. UTC | #2
On Tue, Oct 06, 2020 at 06:56:14PM -0700, Andrii Nakryiko wrote:
> On Tue, Oct 6, 2020 at 1:14 PM Alexei Starovoitov

> <alexei.starovoitov@gmail.com> wrote:

> >

> > From: Alexei Starovoitov <ast@kernel.org>

> >

> > The llvm register allocator may use two different registers representing the

> > same virtual register. In such case the following pattern can be observed:

> > 1047: (bf) r9 = r6

> > 1048: (a5) if r6 < 0x1000 goto pc+1

> > 1050: ...

> > 1051: (a5) if r9 < 0x2 goto pc+66

> > 1052: ...

> > 1053: (bf) r2 = r9 /* r2 needs to have upper and lower bounds */

> >

> > In order to track this information without backtracking allocate ID

> > for scalars in a similar way as it's done for find_good_pkt_pointers().

> >

> > When the verifier encounters r9 = r6 assignment it will assign the same ID

> > to both registers. Later if either register range is narrowed via conditional

> > jump propagate the register state into the other register.

> >

> > Clear register ID in adjust_reg_min_max_vals() for any alu instruction.

> >

> > Newly allocated register ID is ignored for scalars in regsafe() and doesn't

> > affect state pruning. mark_reg_unknown() also clears the ID.

> >

> > Signed-off-by: Alexei Starovoitov <ast@kernel.org>

> > ---

> 

> I couldn't find the problem with the logic, though it's quite

> non-obvious at times that reg->id will be cleared on BPF_END/BPF_NEG

> and few other operations. But I think naming of this function can be

> improved, see below.

> 

> Also, profiler.c is great, but it would still be nice to add selftest

> to test_verifier that will explicitly test the logic in this patch


the test align.c actualy does the id checking better than I expected.
I'm planning to add more asm tests in the follow up.

> >  kernel/bpf/verifier.c                         | 38 +++++++++++++++++++

> >  .../testing/selftests/bpf/prog_tests/align.c  | 16 ++++----

> >  .../bpf/verifier/direct_packet_access.c       |  2 +-

> >  3 files changed, 47 insertions(+), 9 deletions(-)

> >

> > diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c

> > index 01120acab09a..09e17b483b0b 100644

> > --- a/kernel/bpf/verifier.c

> > +++ b/kernel/bpf/verifier.c

> > @@ -6432,6 +6432,8 @@ static int adjust_reg_min_max_vals(struct bpf_verifier_env *env,

> >         src_reg = NULL;

> >         if (dst_reg->type != SCALAR_VALUE)

> >                 ptr_reg = dst_reg;

> > +       else

> > +               dst_reg->id = 0;

> >         if (BPF_SRC(insn->code) == BPF_X) {

> >                 src_reg = &regs[insn->src_reg];

> >                 if (src_reg->type != SCALAR_VALUE) {

> > @@ -6565,6 +6567,8 @@ static int check_alu_op(struct bpf_verifier_env *env, struct bpf_insn *insn)

> >                                 /* case: R1 = R2

> >                                  * copy register state to dest reg

> >                                  */

> > +                               if (src_reg->type == SCALAR_VALUE)

> > +                                       src_reg->id = ++env->id_gen;

> >                                 *dst_reg = *src_reg;

> >                                 dst_reg->live |= REG_LIVE_WRITTEN;

> >                                 dst_reg->subreg_def = DEF_NOT_SUBREG;

> > @@ -7365,6 +7369,30 @@ static bool try_match_pkt_pointers(const struct bpf_insn *insn,

> >         return true;

> >  }

> >

> > +static void find_equal_scalars(struct bpf_verifier_state *vstate,

> > +                              struct bpf_reg_state *known_reg)

> 

> this is double-misleading name:

> 

> 1) it's not just "find", but also "update" (or rather the purpose of

> this function is specifically to update registers, not find them, as

> we don't really return found register)

> 2) "equal" is not exactly true either. You can have two scalar

> register with exactly the same state, but they might not share ->id.

> So it's less about being equal, rather being "linked" by assignment.


I don't think I can agree.
We already have find_good_pkt_pointers() that also updates,
so 'find' fits better than 'update'.
'linked' is also wrong. The regs are exactly equal.
In case of pkt and other pointers two regs will have the same id
as well, but they will not be equal. Here these two scalars are equal
otherwise doing *reg = *known_reg would be wrong.

> > +{

> > +       struct bpf_func_state *state;

> > +       struct bpf_reg_state *reg;

> > +       int i, j;

> > +

> > +       for (i = 0; i <= vstate->curframe; i++) {

> > +               state = vstate->frame[i];

> > +               for (j = 0; j < MAX_BPF_REG; j++) {

> > +                       reg = &state->regs[j];

> > +                       if (reg->type == SCALAR_VALUE && reg->id == known_reg->id)

> > +                               *reg = *known_reg;

> > +               }

> > +

> > +               bpf_for_each_spilled_reg(j, state, reg) {

> > +                       if (!reg)

> > +                               continue;

> > +                       if (reg->type == SCALAR_VALUE && reg->id == known_reg->id)

> > +                               *reg = *known_reg;

> > +               }

> > +       }

> > +}

> > +

> >  static int check_cond_jmp_op(struct bpf_verifier_env *env,

> >                              struct bpf_insn *insn, int *insn_idx)

> >  {

> > @@ -7493,6 +7521,11 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env,

> >                                 reg_combine_min_max(&other_branch_regs[insn->src_reg],

> >                                                     &other_branch_regs[insn->dst_reg],

> >                                                     src_reg, dst_reg, opcode);

> > +                       if (src_reg->id) {

> > +                               find_equal_scalars(this_branch, src_reg);

> > +                               find_equal_scalars(other_branch, &other_branch_regs[insn->src_reg]);

> > +                       }

> > +

> >                 }

> >         } else if (dst_reg->type == SCALAR_VALUE) {

> >                 reg_set_min_max(&other_branch_regs[insn->dst_reg],

> > @@ -7500,6 +7533,11 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env,

> >                                         opcode, is_jmp32);

> >         }

> >

> > +       if (dst_reg->type == SCALAR_VALUE && dst_reg->id) {

> > +               find_equal_scalars(this_branch, dst_reg);

> > +               find_equal_scalars(other_branch, &other_branch_regs[insn->dst_reg]);

> 

> will this cover the case above where reg_combine_min_max() can update

> dst_reg's as well? 


yes.

> Even if yes, it probably would be more

> straightforward to call appropriate updates in the respective if

> branches (it's just a single line for each register, so not like it's

> duplicating tons of code). 


You mean inside reg_set_min_max() and inside reg_combine_min_max() ?
That won't work because find_equal_scalars() needs access to the whole
bpf_verifier_state and not just bpf_reg_state.

> It will make reasoning about this logic

> easier, IMO. Also, moving reg->id check into find_equal_scalars()

> would make the above suggestion even cleaner.


I don't think so. I think checking for type == SCALAR && dst_reg->id != 0
should be done outside of that function. It makes the logic cleaner.
For the same reason we check type outside of find_good_pkt_pointers().
Andrii Nakryiko Oct. 7, 2020, 3:31 a.m. UTC | #3
On Tue, Oct 6, 2020 at 7:18 PM Alexei Starovoitov
<alexei.starovoitov@gmail.com> wrote:
>

> On Tue, Oct 06, 2020 at 06:56:14PM -0700, Andrii Nakryiko wrote:

> > On Tue, Oct 6, 2020 at 1:14 PM Alexei Starovoitov

> > <alexei.starovoitov@gmail.com> wrote:

> > >

> > > From: Alexei Starovoitov <ast@kernel.org>

> > >

> > > The llvm register allocator may use two different registers representing the

> > > same virtual register. In such case the following pattern can be observed:

> > > 1047: (bf) r9 = r6

> > > 1048: (a5) if r6 < 0x1000 goto pc+1

> > > 1050: ...

> > > 1051: (a5) if r9 < 0x2 goto pc+66

> > > 1052: ...

> > > 1053: (bf) r2 = r9 /* r2 needs to have upper and lower bounds */

> > >

> > > In order to track this information without backtracking allocate ID

> > > for scalars in a similar way as it's done for find_good_pkt_pointers().

> > >

> > > When the verifier encounters r9 = r6 assignment it will assign the same ID

> > > to both registers. Later if either register range is narrowed via conditional

> > > jump propagate the register state into the other register.

> > >

> > > Clear register ID in adjust_reg_min_max_vals() for any alu instruction.

> > >

> > > Newly allocated register ID is ignored for scalars in regsafe() and doesn't

> > > affect state pruning. mark_reg_unknown() also clears the ID.

> > >

> > > Signed-off-by: Alexei Starovoitov <ast@kernel.org>

> > > ---

> >

> > I couldn't find the problem with the logic, though it's quite

> > non-obvious at times that reg->id will be cleared on BPF_END/BPF_NEG

> > and few other operations. But I think naming of this function can be

> > improved, see below.

> >

> > Also, profiler.c is great, but it would still be nice to add selftest

> > to test_verifier that will explicitly test the logic in this patch

>

> the test align.c actualy does the id checking better than I expected.

> I'm planning to add more asm tests in the follow up.

>


ok

Acked-by: Andrii Nakryiko <andrii@kernel.org>


> > >  kernel/bpf/verifier.c                         | 38 +++++++++++++++++++

> > >  .../testing/selftests/bpf/prog_tests/align.c  | 16 ++++----

> > >  .../bpf/verifier/direct_packet_access.c       |  2 +-

> > >  3 files changed, 47 insertions(+), 9 deletions(-)

> > >

> > > diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c

> > > index 01120acab09a..09e17b483b0b 100644

> > > --- a/kernel/bpf/verifier.c

> > > +++ b/kernel/bpf/verifier.c

> > > @@ -6432,6 +6432,8 @@ static int adjust_reg_min_max_vals(struct bpf_verifier_env *env,

> > >         src_reg = NULL;

> > >         if (dst_reg->type != SCALAR_VALUE)

> > >                 ptr_reg = dst_reg;

> > > +       else

> > > +               dst_reg->id = 0;

> > >         if (BPF_SRC(insn->code) == BPF_X) {

> > >                 src_reg = &regs[insn->src_reg];

> > >                 if (src_reg->type != SCALAR_VALUE) {

> > > @@ -6565,6 +6567,8 @@ static int check_alu_op(struct bpf_verifier_env *env, struct bpf_insn *insn)

> > >                                 /* case: R1 = R2

> > >                                  * copy register state to dest reg

> > >                                  */

> > > +                               if (src_reg->type == SCALAR_VALUE)

> > > +                                       src_reg->id = ++env->id_gen;

> > >                                 *dst_reg = *src_reg;

> > >                                 dst_reg->live |= REG_LIVE_WRITTEN;

> > >                                 dst_reg->subreg_def = DEF_NOT_SUBREG;

> > > @@ -7365,6 +7369,30 @@ static bool try_match_pkt_pointers(const struct bpf_insn *insn,

> > >         return true;

> > >  }

> > >

> > > +static void find_equal_scalars(struct bpf_verifier_state *vstate,

> > > +                              struct bpf_reg_state *known_reg)

> >

> > this is double-misleading name:

> >

> > 1) it's not just "find", but also "update" (or rather the purpose of

> > this function is specifically to update registers, not find them, as

> > we don't really return found register)

> > 2) "equal" is not exactly true either. You can have two scalar

> > register with exactly the same state, but they might not share ->id.

> > So it's less about being equal, rather being "linked" by assignment.

>

> I don't think I can agree.

> We already have find_good_pkt_pointers() that also updates,

> so 'find' fits better than 'update'.


find_good_pkt_pointers() has similarly confusing name, but sure,
consistency rules

> 'linked' is also wrong. The regs are exactly equal.

> In case of pkt and other pointers two regs will have the same id

> as well, but they will not be equal. Here these two scalars are equal

> otherwise doing *reg = *known_reg would be wrong.


Ok, I guess it also means that "reg->type == SCALAR_VALUE" checks
below are unnecessary as well, because if known_reg->id matches, that
means register states are exactly the same.

>

> > > +{

> > > +       struct bpf_func_state *state;

> > > +       struct bpf_reg_state *reg;

> > > +       int i, j;

> > > +

> > > +       for (i = 0; i <= vstate->curframe; i++) {

> > > +               state = vstate->frame[i];

> > > +               for (j = 0; j < MAX_BPF_REG; j++) {

> > > +                       reg = &state->regs[j];

> > > +                       if (reg->type == SCALAR_VALUE && reg->id == known_reg->id)

> > > +                               *reg = *known_reg;

> > > +               }

> > > +

> > > +               bpf_for_each_spilled_reg(j, state, reg) {

> > > +                       if (!reg)

> > > +                               continue;

> > > +                       if (reg->type == SCALAR_VALUE && reg->id == known_reg->id)

> > > +                               *reg = *known_reg;

> > > +               }

> > > +       }

> > > +}

> > > +

> > >  static int check_cond_jmp_op(struct bpf_verifier_env *env,

> > >                              struct bpf_insn *insn, int *insn_idx)

> > >  {

> > > @@ -7493,6 +7521,11 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env,

> > >                                 reg_combine_min_max(&other_branch_regs[insn->src_reg],

> > >                                                     &other_branch_regs[insn->dst_reg],

> > >                                                     src_reg, dst_reg, opcode);

> > > +                       if (src_reg->id) {

> > > +                               find_equal_scalars(this_branch, src_reg);

> > > +                               find_equal_scalars(other_branch, &other_branch_regs[insn->src_reg]);

> > > +                       }

> > > +

> > >                 }

> > >         } else if (dst_reg->type == SCALAR_VALUE) {

> > >                 reg_set_min_max(&other_branch_regs[insn->dst_reg],

> > > @@ -7500,6 +7533,11 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env,

> > >                                         opcode, is_jmp32);

> > >         }

> > >

> > > +       if (dst_reg->type == SCALAR_VALUE && dst_reg->id) {

> > > +               find_equal_scalars(this_branch, dst_reg);

> > > +               find_equal_scalars(other_branch, &other_branch_regs[insn->dst_reg]);

> >

> > will this cover the case above where reg_combine_min_max() can update

> > dst_reg's as well?

>

> yes.

>

> > Even if yes, it probably would be more

> > straightforward to call appropriate updates in the respective if

> > branches (it's just a single line for each register, so not like it's

> > duplicating tons of code).

>

> You mean inside reg_set_min_max() and inside reg_combine_min_max() ?

> That won't work because find_equal_scalars() needs access to the whole

> bpf_verifier_state and not just bpf_reg_state.


No, I meant something like this, few lines above:

if (BPF_SRC(insn->code) == BPF_X) {

    if (dst_reg->type == SCALAR_VALUE && src_reg->type == SCALAR_VALUE) {
        if (...)
        else if (...)
        else

        /* both src/dst regs in both this/other branches could have
been updated */
        find_equal_scalars(this_branch, src_reg);
        find_equal_scalars(this_branch, dst_reg);
        find_equal_scalars(other_branch, &other_branch_regs[insn->src_reg])
        find_equal_scalars(other_branch, &other_branch_regs[insn->dst_reg])
    }
} else if (dst_reg->type == SCALAR_VALUE) {
    reg_set_min_max(...);

    /* only dst_reg in both branches could have been updated */
    find_equal_scalars(this_branch, dst_reg);
    find_equal_scalars(other_branch, &other_branch_regs[insn->dst_reg]);
}


This keeps find_equal_scalars() for relevant registers very close to
places where those registers are updated, instead of jumping back and
forth between the complicated if  after it, and double-checking under
which circumstances dst_reg can be updated, for example.

>

> > It will make reasoning about this logic

> > easier, IMO. Also, moving reg->id check into find_equal_scalars()

> > would make the above suggestion even cleaner.

>

> I don't think so. I think checking for type == SCALAR && dst_reg->id != 0

> should be done outside of that function. It makes the logic cleaner.

> For the same reason we check type outside of find_good_pkt_pointers().
John Fastabend Oct. 7, 2020, 11:44 p.m. UTC | #4
Alexei Starovoitov wrote:
> From: Alexei Starovoitov <ast@kernel.org>

> 

> The llvm register allocator may use two different registers representing the

> same virtual register. In such case the following pattern can be observed:

> 1047: (bf) r9 = r6

> 1048: (a5) if r6 < 0x1000 goto pc+1

> 1050: ...

> 1051: (a5) if r9 < 0x2 goto pc+66

> 1052: ...

> 1053: (bf) r2 = r9 /* r2 needs to have upper and lower bounds */

> 

> In order to track this information without backtracking allocate ID

> for scalars in a similar way as it's done for find_good_pkt_pointers().

> 

> When the verifier encounters r9 = r6 assignment it will assign the same ID

> to both registers. Later if either register range is narrowed via conditional

> jump propagate the register state into the other register.

> 

> Clear register ID in adjust_reg_min_max_vals() for any alu instruction.


Do we also need to clear the register ID on reg0 for CALL ops into a
helper?

Looks like check_helper_call might mark reg0 as a scalar, but I don't
see where it would clear the reg->id? Did I miss it. Either way maybe
a comment here would help make it obvious how CALLs are handled?

Thanks,
John

> 

> Newly allocated register ID is ignored for scalars in regsafe() and doesn't

> affect state pruning. mark_reg_unknown() also clears the ID.

> 

> Signed-off-by: Alexei Starovoitov <ast@kernel.org>

> ---

>  kernel/bpf/verifier.c                         | 38 +++++++++++++++++++

>  .../testing/selftests/bpf/prog_tests/align.c  | 16 ++++----

>  .../bpf/verifier/direct_packet_access.c       |  2 +-

>  3 files changed, 47 insertions(+), 9 deletions(-)

> 

> diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c

> index 01120acab09a..09e17b483b0b 100644

> --- a/kernel/bpf/verifier.c

> +++ b/kernel/bpf/verifier.c

> @@ -6432,6 +6432,8 @@ static int adjust_reg_min_max_vals(struct bpf_verifier_env *env,

>  	src_reg = NULL;

>  	if (dst_reg->type != SCALAR_VALUE)

>  		ptr_reg = dst_reg;

> +	else

> +		dst_reg->id = 0;

>  	if (BPF_SRC(insn->code) == BPF_X) {

>  		src_reg = &regs[insn->src_reg];

>  		if (src_reg->type != SCALAR_VALUE) {

> @@ -6565,6 +6567,8 @@ static int check_alu_op(struct bpf_verifier_env *env, struct bpf_insn *insn)

>  				/* case: R1 = R2

>  				 * copy register state to dest reg

>  				 */

> +				if (src_reg->type == SCALAR_VALUE)

> +					src_reg->id = ++env->id_gen;

>  				*dst_reg = *src_reg;

>  				dst_reg->live |= REG_LIVE_WRITTEN;

>  				dst_reg->subreg_def = DEF_NOT_SUBREG;

> @@ -7365,6 +7369,30 @@ static bool try_match_pkt_pointers(const struct bpf_insn *insn,

>  	return true;

>  }
diff mbox series

Patch

diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
index 01120acab09a..09e17b483b0b 100644
--- a/kernel/bpf/verifier.c
+++ b/kernel/bpf/verifier.c
@@ -6432,6 +6432,8 @@  static int adjust_reg_min_max_vals(struct bpf_verifier_env *env,
 	src_reg = NULL;
 	if (dst_reg->type != SCALAR_VALUE)
 		ptr_reg = dst_reg;
+	else
+		dst_reg->id = 0;
 	if (BPF_SRC(insn->code) == BPF_X) {
 		src_reg = &regs[insn->src_reg];
 		if (src_reg->type != SCALAR_VALUE) {
@@ -6565,6 +6567,8 @@  static int check_alu_op(struct bpf_verifier_env *env, struct bpf_insn *insn)
 				/* case: R1 = R2
 				 * copy register state to dest reg
 				 */
+				if (src_reg->type == SCALAR_VALUE)
+					src_reg->id = ++env->id_gen;
 				*dst_reg = *src_reg;
 				dst_reg->live |= REG_LIVE_WRITTEN;
 				dst_reg->subreg_def = DEF_NOT_SUBREG;
@@ -7365,6 +7369,30 @@  static bool try_match_pkt_pointers(const struct bpf_insn *insn,
 	return true;
 }
 
+static void find_equal_scalars(struct bpf_verifier_state *vstate,
+			       struct bpf_reg_state *known_reg)
+{
+	struct bpf_func_state *state;
+	struct bpf_reg_state *reg;
+	int i, j;
+
+	for (i = 0; i <= vstate->curframe; i++) {
+		state = vstate->frame[i];
+		for (j = 0; j < MAX_BPF_REG; j++) {
+			reg = &state->regs[j];
+			if (reg->type == SCALAR_VALUE && reg->id == known_reg->id)
+				*reg = *known_reg;
+		}
+
+		bpf_for_each_spilled_reg(j, state, reg) {
+			if (!reg)
+				continue;
+			if (reg->type == SCALAR_VALUE && reg->id == known_reg->id)
+				*reg = *known_reg;
+		}
+	}
+}
+
 static int check_cond_jmp_op(struct bpf_verifier_env *env,
 			     struct bpf_insn *insn, int *insn_idx)
 {
@@ -7493,6 +7521,11 @@  static int check_cond_jmp_op(struct bpf_verifier_env *env,
 				reg_combine_min_max(&other_branch_regs[insn->src_reg],
 						    &other_branch_regs[insn->dst_reg],
 						    src_reg, dst_reg, opcode);
+			if (src_reg->id) {
+				find_equal_scalars(this_branch, src_reg);
+				find_equal_scalars(other_branch, &other_branch_regs[insn->src_reg]);
+			}
+
 		}
 	} else if (dst_reg->type == SCALAR_VALUE) {
 		reg_set_min_max(&other_branch_regs[insn->dst_reg],
@@ -7500,6 +7533,11 @@  static int check_cond_jmp_op(struct bpf_verifier_env *env,
 					opcode, is_jmp32);
 	}
 
+	if (dst_reg->type == SCALAR_VALUE && dst_reg->id) {
+		find_equal_scalars(this_branch, dst_reg);
+		find_equal_scalars(other_branch, &other_branch_regs[insn->dst_reg]);
+	}
+
 	/* detect if R == 0 where R is returned from bpf_map_lookup_elem().
 	 * NOTE: these optimizations below are related with pointer comparison
 	 *       which will never be JMP32.
diff --git a/tools/testing/selftests/bpf/prog_tests/align.c b/tools/testing/selftests/bpf/prog_tests/align.c
index c548aded6585..56a414ce5504 100644
--- a/tools/testing/selftests/bpf/prog_tests/align.c
+++ b/tools/testing/selftests/bpf/prog_tests/align.c
@@ -195,13 +195,13 @@  static struct bpf_align_test tests[] = {
 		.prog_type = BPF_PROG_TYPE_SCHED_CLS,
 		.matches = {
 			{7, "R3_w=inv(id=0,umax_value=255,var_off=(0x0; 0xff))"},
-			{8, "R4_w=inv(id=0,umax_value=255,var_off=(0x0; 0xff))"},
+			{8, "R4_w=inv(id=1,umax_value=255,var_off=(0x0; 0xff))"},
 			{9, "R4_w=inv(id=0,umax_value=255,var_off=(0x0; 0xff))"},
-			{10, "R4_w=inv(id=0,umax_value=255,var_off=(0x0; 0xff))"},
+			{10, "R4_w=inv(id=2,umax_value=255,var_off=(0x0; 0xff))"},
 			{11, "R4_w=inv(id=0,umax_value=510,var_off=(0x0; 0x1fe))"},
-			{12, "R4_w=inv(id=0,umax_value=255,var_off=(0x0; 0xff))"},
+			{12, "R4_w=inv(id=3,umax_value=255,var_off=(0x0; 0xff))"},
 			{13, "R4_w=inv(id=0,umax_value=1020,var_off=(0x0; 0x3fc))"},
-			{14, "R4_w=inv(id=0,umax_value=255,var_off=(0x0; 0xff))"},
+			{14, "R4_w=inv(id=4,umax_value=255,var_off=(0x0; 0xff))"},
 			{15, "R4_w=inv(id=0,umax_value=2040,var_off=(0x0; 0x7f8))"},
 			{16, "R4_w=inv(id=0,umax_value=4080,var_off=(0x0; 0xff0))"},
 		},
@@ -518,7 +518,7 @@  static struct bpf_align_test tests[] = {
 			 * the total offset is 4-byte aligned and meets the
 			 * load's requirements.
 			 */
-			{20, "R5=pkt(id=1,off=0,r=4,umin_value=2,umax_value=1034,var_off=(0x2; 0x7fc)"},
+			{20, "R5=pkt(id=2,off=0,r=4,umin_value=2,umax_value=1034,var_off=(0x2; 0x7fc)"},
 
 		},
 	},
@@ -561,18 +561,18 @@  static struct bpf_align_test tests[] = {
 			/* Adding 14 makes R6 be (4n+2) */
 			{11, "R6_w=inv(id=0,umin_value=14,umax_value=74,var_off=(0x2; 0x7c))"},
 			/* Subtracting from packet pointer overflows ubounds */
-			{13, "R5_w=pkt(id=1,off=0,r=8,umin_value=18446744073709551542,umax_value=18446744073709551602,var_off=(0xffffffffffffff82; 0x7c)"},
+			{13, "R5_w=pkt(id=2,off=0,r=8,umin_value=18446744073709551542,umax_value=18446744073709551602,var_off=(0xffffffffffffff82; 0x7c)"},
 			/* New unknown value in R7 is (4n), >= 76 */
 			{15, "R7_w=inv(id=0,umin_value=76,umax_value=1096,var_off=(0x0; 0x7fc))"},
 			/* Adding it to packet pointer gives nice bounds again */
-			{16, "R5_w=pkt(id=2,off=0,r=0,umin_value=2,umax_value=1082,var_off=(0x2; 0xfffffffc)"},
+			{16, "R5_w=pkt(id=3,off=0,r=0,umin_value=2,umax_value=1082,var_off=(0x2; 0xfffffffc)"},
 			/* At the time the word size load is performed from R5,
 			 * its total fixed offset is NET_IP_ALIGN + reg->off (0)
 			 * which is 2.  Then the variable offset is (4n+2), so
 			 * the total offset is 4-byte aligned and meets the
 			 * load's requirements.
 			 */
-			{20, "R5=pkt(id=2,off=0,r=4,umin_value=2,umax_value=1082,var_off=(0x2; 0xfffffffc)"},
+			{20, "R5=pkt(id=3,off=0,r=4,umin_value=2,umax_value=1082,var_off=(0x2; 0xfffffffc)"},
 		},
 	},
 };
diff --git a/tools/testing/selftests/bpf/verifier/direct_packet_access.c b/tools/testing/selftests/bpf/verifier/direct_packet_access.c
index 2c5fbe7bcd27..ae72536603fe 100644
--- a/tools/testing/selftests/bpf/verifier/direct_packet_access.c
+++ b/tools/testing/selftests/bpf/verifier/direct_packet_access.c
@@ -529,7 +529,7 @@ 
 	},
 	.prog_type = BPF_PROG_TYPE_SCHED_CLS,
 	.result = REJECT,
-	.errstr = "invalid access to packet, off=0 size=8, R5(id=1,off=0,r=0)",
+	.errstr = "invalid access to packet, off=0 size=8, R5(id=2,off=0,r=0)",
 	.flags = F_NEEDS_EFFICIENT_UNALIGNED_ACCESS,
 },
 {