
Support fused multiply-adds in fully-masked reductions

Message ID 87zi10neb6.fsf@linaro.org
State New
Series Support fused multiply-adds in fully-masked reductions

Commit Message

Richard Sandiford May 16, 2018, 9:24 a.m. UTC
This patch adds support for fusing a conditional add or subtract
with a multiplication, so that we can use fused multiply-add and
multiply-subtract operations for fully-masked reductions.  E.g.
for SVE we vectorise:

  double res = 0.0;
  for (int i = 0; i < n; ++i)
    res += x[i] * y[i];

using a fully-masked loop in which the loop body has the form:

  res_1 = PHI<0(preheader), res_2(latch)>;
  avec = IFN_MASK_LOAD (loop_mask, x);
  bvec = IFN_MASK_LOAD (loop_mask, y);
  prod = avec * bvec;
  res_2 = IFN_COND_ADD (loop_mask, res_1, prod);

where the last statement does the equivalent of:

  res_2 = loop_mask ? res_1 + prod : res_1;

(operating elementwise).  The point of the patch is to convert the last
two statements into a single internal function call that is the equivalent of:

  res_2 = loop_mask ? fma (avec, bvec, res_1) : res_1;

(again operating elementwise).

All current conditional X operations have the form "do X or don't do X
to the first operand" (add/don't add to the first operand, etc.).  However,
the FMA optabs and functions are ordered so that the accumulator comes
last.  There were two obvious ways of resolving this: break the
convention for conditional operators and have "add/don't add to the
final operand", or break the convention for FMA and put the accumulator
first.  The patch goes for the latter, but adds _REV to make it obvious
that the operands are in a different order.

Tested on aarch64-linux-gnu (with and without SVE), aarch64_be-elf
and x86_64-linux-gnu.  OK to install?

Richard


2018-05-16  Richard Sandiford  <richard.sandiford@linaro.org>
	    Alan Hayward  <alan.hayward@arm.com>
	    David Sherwood  <david.sherwood@arm.com>

gcc/
	* doc/md.texi (cond_fma_rev, cond_fnma_rev): Document.
	* optabs.def (cond_fma_rev, cond_fnma_rev): New optabs.
	* internal-fn.def (COND_FMA_REV, COND_FNMA_REV): New internal
	functions.
	* internal-fn.h (can_interpret_as_conditional_op_p): Declare.
	* internal-fn.c (cond_ternary_direct): New macro.
	(expand_cond_ternary_optab_fn): Likewise.
	(direct_cond_ternary_optab_supported_p): Likewise.
	(FOR_EACH_CODE_MAPPING): Likewise.
	(get_conditional_internal_fn): Use FOR_EACH_CODE_MAPPING.
	(conditional_internal_fn_code): New function.
	(can_interpret_as_conditional_op_p): Likewise.
	* tree-ssa-math-opts.c (fused_cond_internal_fn): New function.
	(convert_mult_to_fma_1): Transform calls to IFN_COND_ADD to
	IFN_COND_FMA_REV and calls to IFN_COND_SUB to IFN_COND_FNMA_REV.
	(convert_mult_to_fma): Handle calls to IFN_COND_ADD and IFN_COND_SUB.
	* genmatch.c (commutative_op): Handle CFN_COND_FMA_REV and
	CFN_COND_FNMA_REV.
	* config/aarch64/iterators.md (UNSPEC_COND_FMLA): New unspec.
	(UNSPEC_COND_FMLS): Likewise.
	(optab, sve_fp_op): Handle them.
	(SVE_COND_INT_OP): Rename to...
	(SVE_COND_INT2_OP): ...this.
	(SVE_COND_FP_OP): Rename to...
	(SVE_COND_FP2_OP): ...this.
	(SVE_COND_FP3_OP): New iterator.
	* config/aarch64/aarch64-sve.md (cond_<optab><mode>): Update
	for new iterator names.  Add a pattern for SVE_COND_FP3_OP.

gcc/testsuite/
	* gcc.target/aarch64/sve/reduc_4.c: New test.
	* gcc.target/aarch64/sve/reduc_6.c: Likewise.
	* gcc.target/aarch64/sve/reduc_7.c: Likewise.
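The FOR_EACH_CODE_MAPPING entry above uses the X-macro idiom. One list of (tree code, conditional function) pairs is expanded twice: once into the forward switch used by get_conditional_internal_fn, and once into the reverse switch of conditional_internal_fn_code. Because both switches come from the same list, they cannot drift apart.

The standalone sketch below shows the idiom in isolation. The enums are simplified, hypothetical stand-ins for GCC's tree_code and internal_fn. The function names get_conditional_fn and conditional_fn_code are likewise made up for illustration.

```c
/* Hypothetical stand-ins for GCC's tree_code and internal_fn enums.  */
enum code { PLUS_EXPR, MINUS_EXPR, MIN_EXPR, MAX_EXPR, ERROR_MARK };
enum ifn { IFN_COND_ADD, IFN_COND_SUB, IFN_COND_MIN, IFN_COND_MAX, IFN_LAST };

/* Single source of truth: invoke T (CODE, IFN) for each mapping.  */
#define FOR_EACH_CODE_MAPPING(T) \
  T (PLUS_EXPR, IFN_COND_ADD) \
  T (MINUS_EXPR, IFN_COND_SUB) \
  T (MIN_EXPR, IFN_COND_MIN) \
  T (MAX_EXPR, IFN_COND_MAX)

/* Forward direction: tree code -> conditional function.  */
enum ifn
get_conditional_fn (enum code c)
{
  switch (c)
    {
#define CASE(CODE, IFN) case CODE: return IFN;
      FOR_EACH_CODE_MAPPING (CASE)
#undef CASE
    default:
      return IFN_LAST;
    }
}

/* Reverse direction: conditional function -> tree code.  */
enum code
conditional_fn_code (enum ifn f)
{
  switch (f)
    {
#define CASE(CODE, IFN) case IFN: return CODE;
      FOR_EACH_CODE_MAPPING (CASE)
#undef CASE
    default:
      return ERROR_MARK;
    }
}
```

Adding a new pair to the list updates both lookups at once. This is the design point behind replacing the hand-written switch in get_conditional_internal_fn.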

Comments

Richard Biener May 24, 2018, 10:30 a.m. UTC
On Wed, May 16, 2018 at 11:26 AM Richard Sandiford <richard.sandiford@linaro.org> wrote:

> All current conditional X operations have the form "do X or don't do X
> to the first operand" (add/don't add to first operand, etc.).  However,
> the FMA optabs and functions are ordered so that the accumulator comes
> last.  There were two obvious ways of resolving this: break the
> convention for conditional operators and have "add/don't add to the
> final operand" or break the convention for FMA and put the accumulator
> first.  The patch goes for the latter, but adds _REV to make it obvious
> that the operands are in a different order.


Eh.  I guess you'll do the same to SAD/DOT_PROD/WIDEN_SUM?

That said, I don't really see the "do or not do to the first operand", it's
"do or not do the operation on operands 1 to 2 (or 3)".  None of the
current ops modify operand 1, they all produce a new value, no?

> Tested on aarch64-linux-gnu (with and without SVE), aarch64_be-elf
> and x86_64-linux-gnu.  OK to install?


OK, but as said I don't see a reason for the operand order to differ
in the first place.

Richard.

> Index: gcc/doc/md.texi
> ===================================================================
> --- gcc/doc/md.texi     2018-05-16 10:23:03.590853492 +0100
> +++ gcc/doc/md.texi     2018-05-16 10:23:03.886838736 +0100
> @@ -6367,6 +6367,32 @@ be in a normal C @samp{?:} condition.
>   Operands 0, 2 and 3 all have mode @var{m}, while operand 1 has the mode
>   returned by @code{TARGET_VECTORIZE_GET_MASK_MODE}.
>
> +@cindex @code{cond_fma_rev@var{mode}} instruction pattern
> +@item @samp{cond_fma_rev@var{mode}}
> +Similar to @samp{cond_add@var{m}}, but compute:
> +@smallexample
> +op0 = op1 ? fma (op3, op4, op2) : op2;
> +@end smallexample
> +for scalars and:
> +@smallexample
> +op0[I] = op1[I] ? fma (op3[I], op4[I], op2[I]) : op2[I];
> +@end smallexample
> +for vectors.  The @samp{_rev} indicates that the addend (operand 2)
> +comes first.
> +
> +@cindex @code{cond_fnma_rev@var{mode}} instruction pattern
> +@item @samp{cond_fnma_rev@var{mode}}
> +Similar to @samp{cond_fma_rev@var{m}}, but negate operand 3 before
> +multiplying it.  That is, the instruction performs:
> +@smallexample
> +op0 = op1 ? fma (-op3, op4, op2) : op2;
> +@end smallexample
> +for scalars and:
> +@smallexample
> +op0[I] = op1[I] ? fma (-op3[I], op4[I], op2[I]) : op2[I];
> +@end smallexample
> +for vectors.
> +
>   @cindex @code{neg@var{mode}cc} instruction pattern
>   @item @samp{neg@var{mode}cc}
>   Similar to @samp{mov@var{mode}cc} but for conditional negation.  Conditionally
> Index: gcc/optabs.def
> ===================================================================
> --- gcc/optabs.def      2018-05-16 10:23:03.590853492 +0100
> +++ gcc/optabs.def      2018-05-16 10:23:03.887838686 +0100
> @@ -222,6 +222,8 @@ OPTAB_D (notcc_optab, "not$acc")
>   OPTAB_D (movcc_optab, "mov$acc")
>   OPTAB_D (cond_add_optab, "cond_add$a")
>   OPTAB_D (cond_sub_optab, "cond_sub$a")
> +OPTAB_D (cond_fma_rev_optab, "cond_fma_rev$a")
> +OPTAB_D (cond_fnma_rev_optab, "cond_fnma_rev$a")
>   OPTAB_D (cond_and_optab, "cond_and$a")
>   OPTAB_D (cond_ior_optab, "cond_ior$a")
>   OPTAB_D (cond_xor_optab, "cond_xor$a")
> Index: gcc/internal-fn.def
> ===================================================================
> --- gcc/internal-fn.def 2018-05-16 10:23:03.590853492 +0100
> +++ gcc/internal-fn.def 2018-05-16 10:23:03.887838686 +0100
> @@ -59,7 +59,8 @@ along with GCC; see the file COPYING3.
>      - binary: a normal binary optab, such as vec_interleave_lo_<mode>
>      - ternary: a normal ternary optab, such as fma<mode>4
>
> -   - cond_binary: a conditional binary optab, such as add<mode>cc
> +   - cond_binary: a conditional binary optab, such as cond_add<mode>
> +   - cond_ternary: a conditional ternary optab, such as cond_fma_rev<mode>
>
>      - fold_left: for scalar = FN (scalar, vector), keyed off the vector mode
>
> @@ -143,6 +144,9 @@ DEF_INTERNAL_OPTAB_FN (FMS, ECF_CONST, f
>   DEF_INTERNAL_OPTAB_FN (FNMA, ECF_CONST, fnma, ternary)
>   DEF_INTERNAL_OPTAB_FN (FNMS, ECF_CONST, fnms, ternary)
>
> +DEF_INTERNAL_OPTAB_FN (COND_FMA_REV, ECF_CONST, cond_fma_rev, cond_ternary)
> +DEF_INTERNAL_OPTAB_FN (COND_FNMA_REV, ECF_CONST, cond_fnma_rev, cond_ternary)
> +
>   DEF_INTERNAL_OPTAB_FN (COND_ADD, ECF_CONST, cond_add, cond_binary)
>   DEF_INTERNAL_OPTAB_FN (COND_SUB, ECF_CONST, cond_sub, cond_binary)
>   DEF_INTERNAL_SIGNED_OPTAB_FN (COND_MIN, ECF_CONST, first,
> Index: gcc/internal-fn.h
> ===================================================================
> --- gcc/internal-fn.h   2018-05-16 10:23:03.590853492 +0100
> +++ gcc/internal-fn.h   2018-05-16 10:23:03.887838686 +0100
> @@ -191,6 +191,8 @@ direct_internal_fn_supported_p (internal
>   extern bool set_edom_supported_p (void);
>
>   extern internal_fn get_conditional_internal_fn (tree_code);
> +extern bool can_interpret_as_conditional_op_p (gimple *, tree_code *,
> +                                              tree *, tree (&)[3]);
>
>   extern bool internal_load_fn_p (internal_fn);
>   extern bool internal_store_fn_p (internal_fn);
> Index: gcc/internal-fn.c
> ===================================================================
> --- gcc/internal-fn.c   2018-05-16 10:23:03.590853492 +0100
> +++ gcc/internal-fn.c   2018-05-16 10:23:03.887838686 +0100
> @@ -93,6 +93,7 @@ #define binary_direct { 0, 0, true }
>   #define ternary_direct { 0, 0, true }
>   #define cond_unary_direct { 1, 1, true }
>   #define cond_binary_direct { 1, 1, true }
> +#define cond_ternary_direct { 1, 1, true }
>   #define while_direct { 0, 2, false }
>   #define fold_extract_direct { 2, 2, false }
>   #define fold_left_direct { 1, 1, false }
> @@ -2972,6 +2973,9 @@ #define expand_cond_unary_optab_fn(FN, S
>   #define expand_cond_binary_optab_fn(FN, STMT, OPTAB) \
>     expand_direct_optab_fn (FN, STMT, OPTAB, 3)
>
> +#define expand_cond_ternary_optab_fn(FN, STMT, OPTAB) \
> +  expand_direct_optab_fn (FN, STMT, OPTAB, 4)
> +
>   #define expand_fold_extract_optab_fn(FN, STMT, OPTAB) \
>     expand_direct_optab_fn (FN, STMT, OPTAB, 3)
>
> @@ -3054,6 +3058,7 @@ #define direct_binary_optab_supported_p
>   #define direct_ternary_optab_supported_p direct_optab_supported_p
>   #define direct_cond_unary_optab_supported_p direct_optab_supported_p
>   #define direct_cond_binary_optab_supported_p direct_optab_supported_p
> +#define direct_cond_ternary_optab_supported_p direct_optab_supported_p
>   #define direct_mask_load_optab_supported_p direct_optab_supported_p
>   #define direct_load_lanes_optab_supported_p multi_vector_optab_supported_p
>   #define direct_mask_load_lanes_optab_supported_p multi_vector_optab_supported_p
> @@ -3198,6 +3203,17 @@ #define DEF_INTERNAL_FN(CODE, FLAGS, FNS
>     0
>   };
>
> +/* Invoke T(CODE, IFN) for each conditional function IFN that maps to a
> +   tree code CODE.  */
> +#define FOR_EACH_CODE_MAPPING(T) \
> +  T (PLUS_EXPR, IFN_COND_ADD) \
> +  T (MINUS_EXPR, IFN_COND_SUB) \
> +  T (MIN_EXPR, IFN_COND_MIN) \
> +  T (MAX_EXPR, IFN_COND_MAX) \
> +  T (BIT_AND_EXPR, IFN_COND_AND) \
> +  T (BIT_IOR_EXPR, IFN_COND_IOR) \
> +  T (BIT_XOR_EXPR, IFN_COND_XOR)
> +
>   /* Return a function that performs the conditional form of CODE, i.e.:
>
>        LHS = RHS1 ? RHS2 CODE RHS3 : RHS2
> @@ -3210,25 +3226,78 @@ get_conditional_internal_fn (tree_code c
>   {
>     switch (code)
>       {
> -    case PLUS_EXPR:
> -      return IFN_COND_ADD;
> -    case MINUS_EXPR:
> -      return IFN_COND_SUB;
> -    case MIN_EXPR:
> -      return IFN_COND_MIN;
> -    case MAX_EXPR:
> -      return IFN_COND_MAX;
> -    case BIT_AND_EXPR:
> -      return IFN_COND_AND;
> -    case BIT_IOR_EXPR:
> -      return IFN_COND_IOR;
> -    case BIT_XOR_EXPR:
> -      return IFN_COND_XOR;
> +#define CASE(CODE, IFN) case CODE: return IFN;
> +      FOR_EACH_CODE_MAPPING(CASE)
> +#undef CASE
>       default:
>         return IFN_LAST;
>       }
>   }
>
> +/* If IFN implements the conditional form of a tree code, return that
> +   tree code, otherwise return ERROR_MARK.  */
> +
> +static tree_code
> +conditional_internal_fn_code (internal_fn ifn)
> +{
> +  switch (ifn)
> +    {
> +#define CASE(CODE, IFN) case IFN: return CODE;
> +      FOR_EACH_CODE_MAPPING(CASE)
> +#undef CASE
> +    default:
> +      return ERROR_MARK;
> +    }
> +}
> +
> +/* Return true if STMT can be interpreted as a conditional tree code
> +   operation of the form:
> +
> +     LHS = COND ? OP (RHS1, ...) : RHS1;
> +
> +   operating elementwise if the operands are vectors.  This includes
> +   the case of an all-true COND, so that the operation always happens.
> +
> +   When returning true, set:
> +
> +   - *CODE_OUT to the tree code
> +   - *COND_OUT to the condition COND, or to NULL_TREE if the condition
> +     is known to be all-true
> +   - OPS[I] to operand I of *CODE_OUT.  */
> +
> +bool
> +can_interpret_as_conditional_op_p (gimple *stmt, tree_code *code_out,
> +                                  tree *cond_out, tree (&ops)[3])
> +{
> +  if (gassign *assign = dyn_cast <gassign *> (stmt))
> +    {
> +      *code_out = gimple_assign_rhs_code (assign);
> +      *cond_out = NULL_TREE;
> +      ops[0] = gimple_assign_rhs1 (assign);
> +      ops[1] = gimple_assign_rhs2 (assign);
> +      ops[2] = gimple_assign_rhs3 (assign);
> +      return true;
> +    }
> +  if (gcall *call = dyn_cast <gcall *> (stmt))
> +    if (gimple_call_internal_p (call))
> +      {
> +       internal_fn ifn = gimple_call_internal_fn (call);
> +       tree_code code = conditional_internal_fn_code (ifn);
> +       if (code != ERROR_MARK)
> +         {
> +           *code_out = code;
> +           *cond_out = gimple_call_arg (call, 0);
> +           if (integer_truep (*cond_out))
> +             *cond_out = NULL_TREE;
> +           unsigned int nargs = gimple_call_num_args (call) - 1;
> +           for (unsigned int i = 0; i < 3; ++i)
> +             ops[i] = i < nargs ? gimple_call_arg (call, i + 1) : NULL_TREE;
> +           return true;
> +         }
> +      }
> +  return false;
> +}
> +
>   /* Return true if IFN is some form of load from memory.  */
>
>   bool
> Index: gcc/tree-ssa-math-opts.c
> ===================================================================
> --- gcc/tree-ssa-math-opts.c    2018-05-16 10:23:03.590853492 +0100
> +++ gcc/tree-ssa-math-opts.c    2018-05-16 10:23:03.889838586 +0100
> @@ -2640,6 +2640,24 @@ convert_plusminus_to_widen (gimple_stmt_
>     return true;
>   }
>
> +/* Return the internal function that implements:
> +
> +     LHS = COND ? A CODE B * C : A.  */
> +
> +static internal_fn
> +fused_cond_internal_fn (tree_code code)
> +{
> +  switch (code)
> +    {
> +    case PLUS_EXPR:
> +      return IFN_COND_FMA_REV;
> +    case MINUS_EXPR:
> +      return IFN_COND_FNMA_REV;
> +    default:
> +      gcc_unreachable ();
> +    }
> +}
> +
>   /* gimple_fold callback that "valueizes" everything.  */
>
>   static tree
> @@ -2663,7 +2681,6 @@ convert_mult_to_fma_1 (tree mul_result,
>     FOR_EACH_IMM_USE_STMT (use_stmt, imm_iter, mul_result)
>       {
>         gimple_stmt_iterator gsi = gsi_for_stmt (use_stmt);
> -      enum tree_code use_code;
>         tree addop, mulop1 = op1, result = mul_result;
>         bool negate_p = false;
>         gimple_seq seq = NULL;
> @@ -2671,8 +2688,8 @@ convert_mult_to_fma_1 (tree mul_result,
>         if (is_gimple_debug (use_stmt))
>          continue;
>
> -      use_code = gimple_assign_rhs_code (use_stmt);
> -      if (use_code == NEGATE_EXPR)
> +      if (is_gimple_assign (use_stmt)
> +         && gimple_assign_rhs_code (use_stmt) == NEGATE_EXPR)
>          {
>            result = gimple_assign_lhs (use_stmt);
>            use_operand_p use_p;
> @@ -2683,23 +2700,30 @@ convert_mult_to_fma_1 (tree mul_result,
>
>            use_stmt = neguse_stmt;
>            gsi = gsi_for_stmt (use_stmt);
> -         use_code = gimple_assign_rhs_code (use_stmt);
>            negate_p = true;
>          }
>
> -      if (gimple_assign_rhs1 (use_stmt) == result)
> -       {
> -         addop = gimple_assign_rhs2 (use_stmt);
> -         /* a * b - c -> a * b + (-c)  */
> -         if (gimple_assign_rhs_code (use_stmt) == MINUS_EXPR)
> -           addop = gimple_build (&seq, NEGATE_EXPR, type, addop);
> -       }
> +      tree cond, ops[3];
> +      tree_code code;
> +      if (!can_interpret_as_conditional_op_p (use_stmt, &code, &cond, ops))
> +       gcc_unreachable ();
> +      addop = ops[0] == result ? ops[1] : ops[0];
> +
> +      internal_fn ifn;
> +      if (cond)
> +       ifn = fused_cond_internal_fn (code);
>         else
>          {
> -         addop = gimple_assign_rhs1 (use_stmt);
> -         /* a - b * c -> (-b) * c + a */
> -         if (gimple_assign_rhs_code (use_stmt) == MINUS_EXPR)
> -           negate_p = !negate_p;
> +         ifn = IFN_FMA;
> +         if (code == MINUS_EXPR)
> +           {
> +             if (ops[0] == result)
> +               /* a * b - c -> a * b + (-c)  */
> +               addop = gimple_build (&seq, NEGATE_EXPR, type, addop);
> +             else
> +               /* a - b * c -> (-b) * c + a */
> +               negate_p = !negate_p;
> +           }
>          }
>
>         if (negate_p)
> @@ -2707,8 +2731,13 @@ convert_mult_to_fma_1 (tree mul_result,
>
>         if (seq)
>          gsi_insert_seq_before (&gsi, seq, GSI_SAME_STMT);
> -      fma_stmt = gimple_build_call_internal (IFN_FMA, 3, mulop1, op2, addop);
> -      gimple_call_set_lhs (fma_stmt, gimple_assign_lhs (use_stmt));
> +
> +      if (ifn == IFN_FMA)
> +       fma_stmt = gimple_build_call_internal (IFN_FMA, 3, mulop1, op2, addop);
> +      else
> +       fma_stmt = gimple_build_call_internal (ifn, 4, cond, addop,
> +                                              mulop1, op2);
> +      gimple_set_lhs (fma_stmt, gimple_get_lhs (use_stmt));
>         gimple_call_set_nothrow (fma_stmt, !stmt_can_throw_internal (use_stmt));
>         gsi_replace (&gsi, fma_stmt, true);
>         /* Valueize aggressively so that we generate FMS, FNMA and FNMS
> @@ -2891,7 +2920,6 @@ convert_mult_to_fma (gimple *mul_stmt, t
>        as an addition.  */
>     FOR_EACH_IMM_USE_FAST (use_p, imm_iter, mul_result)
>       {
> -      enum tree_code use_code;
>         tree result = mul_result;
>         bool negate_p = false;
>
> @@ -2912,13 +2940,9 @@ convert_mult_to_fma (gimple *mul_stmt, t
>         if (gimple_bb (use_stmt) != gimple_bb (mul_stmt))
>          return false;
>
> -      if (!is_gimple_assign (use_stmt))
> -       return false;
> -
> -      use_code = gimple_assign_rhs_code (use_stmt);
> -
>         /* A negate on the multiplication leads to FNMA.  */
> -      if (use_code == NEGATE_EXPR)
> +      if (is_gimple_assign (use_stmt)
> +         && gimple_assign_rhs_code (use_stmt) == NEGATE_EXPR)
>          {
>            ssa_op_iter iter;
>            use_operand_p usep;
> @@ -2940,17 +2964,19 @@ convert_mult_to_fma (gimple *mul_stmt, t
>            use_stmt = neguse_stmt;
>            if (gimple_bb (use_stmt) != gimple_bb (mul_stmt))
>              return false;
> -         if (!is_gimple_assign (use_stmt))
> -           return false;
>
> -         use_code = gimple_assign_rhs_code (use_stmt);
>            negate_p = true;
>          }
>
> -      switch (use_code)
> +      tree cond, ops[3];
> +      tree_code code;
> +      if (!can_interpret_as_conditional_op_p (use_stmt, &code, &cond, ops))
> +       return false;
> +
> +      switch (code)
>          {
>          case MINUS_EXPR:
> -         if (gimple_assign_rhs2 (use_stmt) == result)
> +         if (ops[1] == result)
>              negate_p = !negate_p;
>            break;
>          case PLUS_EXPR:
> @@ -2960,47 +2986,52 @@ convert_mult_to_fma (gimple *mul_stmt, t
>            return false;
>          }
>
> -      /* If the subtrahend (gimple_assign_rhs2 (use_stmt)) is computed
> -        by a MULT_EXPR that we'll visit later, we might be able to
> -        get a more profitable match with fnma.
> +      if (cond)
> +       {
> +         /* The multiplication must be the second operand.  */
> +         if (cond == result || ops[0] == result)
> +           return false;
> +         internal_fn ifn = fused_cond_internal_fn (code);
> +         if (!direct_internal_fn_supported_p (ifn, type, opt_type))
> +           return false;
> +       }
> +
> +      /* If the subtrahend (OPS[1]) is computed by a MULT_EXPR that
> +        we'll visit later, we might be able to get a more profitable
> +        match with fnma.
>           OTOH, if we don't, a negate / fma pair has likely lower latency
>           that a mult / subtract pair.  */
> -      if (use_code == MINUS_EXPR && !negate_p
> -         && gimple_assign_rhs1 (use_stmt) == result
> +      if (code == MINUS_EXPR
> +         && !negate_p
> +         && ops[0] == result
>            && !direct_internal_fn_supported_p (IFN_FMS, type, opt_type)
> -         && direct_internal_fn_supported_p (IFN_FNMA, type, opt_type))
> +         && direct_internal_fn_supported_p (IFN_FNMA, type, opt_type)
> +         && TREE_CODE (ops[1]) == SSA_NAME
> +         && has_single_use (ops[1]))
>          {
> -         tree rhs2 = gimple_assign_rhs2 (use_stmt);
> -
> -         if (TREE_CODE (rhs2) == SSA_NAME)
> -           {
> -             gimple *stmt2 = SSA_NAME_DEF_STMT (rhs2);
> -             if (has_single_use (rhs2)
> -                 && is_gimple_assign (stmt2)
> -                 && gimple_assign_rhs_code (stmt2) == MULT_EXPR)
> -             return false;
> -           }
> +         gimple *stmt2 = SSA_NAME_DEF_STMT (ops[1]);
> +         if (is_gimple_assign (stmt2)
> +             && gimple_assign_rhs_code (stmt2) == MULT_EXPR)
> +           return false;
>          }
>
> -      tree use_rhs1 = gimple_assign_rhs1 (use_stmt);
> -      tree use_rhs2 = gimple_assign_rhs2 (use_stmt);
>         /* We can't handle a * b + a * b.  */
> -      if (use_rhs1 == use_rhs2)
> +      if (ops[0] == ops[1])
>          return false;
>         /* If deferring, make sure we are not looking at an instruction that
>           wouldn't have existed if we were not.  */
>         if (state->m_deferring_p
> -         && (state->m_mul_result_set.contains (use_rhs1)
> -             || state->m_mul_result_set.contains (use_rhs2)))
> +         && (state->m_mul_result_set.contains (ops[0])
> +             || state->m_mul_result_set.contains (ops[1])))
>          return false;
>
>         if (check_defer)
>          {
> -         tree use_lhs = gimple_assign_lhs (use_stmt);
> +         tree use_lhs = gimple_get_lhs (use_stmt);
>            if (state->m_last_result)
>              {
> -             if (use_rhs2 == state->m_last_result
> -                 || use_rhs1 == state->m_last_result)
> +             if (ops[1] == state->m_last_result
> +                 || ops[0] == state->m_last_result)
>                  defer = true;
>                else
>                  defer = false;
> @@ -3009,12 +3040,12 @@ convert_mult_to_fma (gimple *mul_stmt, t
>              {
>                gcc_checking_assert (!state->m_initial_phi);
>                gphi *phi;
> -             if (use_rhs1 == result)
> -               phi = result_of_phi (use_rhs2);
> +             if (ops[0] == result)
> +               phi = result_of_phi (ops[1]);
>                else
>                  {
> -                 gcc_assert (use_rhs2 == result);
> -                 phi = result_of_phi (use_rhs1);
> +                 gcc_assert (ops[1] == result);
> +                 phi = result_of_phi (ops[0]);
>                  }
>
>                if (phi)
> Index: gcc/genmatch.c
> ===================================================================
> --- gcc/genmatch.c      2018-05-16 10:23:03.590853492 +0100
> +++ gcc/genmatch.c      2018-05-16 10:23:03.887838686 +0100
> @@ -485,6 +485,10 @@ commutative_op (id_base *id)
>         case CFN_FNMS:
>          return 0;
>
> +      case CFN_COND_FMA_REV:
> +      case CFN_COND_FNMA_REV:
> +       return 2;
> +
>         default:
>          return -1;
>         }
> Index: gcc/config/aarch64/iterators.md
> ===================================================================
> --- gcc/config/aarch64/iterators.md     2018-05-16 10:23:03.590853492 +0100
> +++ gcc/config/aarch64/iterators.md     2018-05-16 10:23:03.886838736 +0100
> @@ -449,6 +449,8 @@ (define_c_enum "unspec"
>       UNSPEC_COND_AND    ; Used in aarch64-sve.md.
>       UNSPEC_COND_ORR    ; Used in aarch64-sve.md.
>       UNSPEC_COND_EOR    ; Used in aarch64-sve.md.
> +    UNSPEC_COND_FMLA   ; Used in aarch64-sve.md.
> +    UNSPEC_COND_FMLS   ; Used in aarch64-sve.md.
>       UNSPEC_COND_LT     ; Used in aarch64-sve.md.
>       UNSPEC_COND_LE     ; Used in aarch64-sve.md.
>       UNSPEC_COND_EQ     ; Used in aarch64-sve.md.
> @@ -1499,14 +1501,16 @@ (define_int_iterator UNPACK_UNSIGNED [UN
>
>   (define_int_iterator MUL_HIGHPART [UNSPEC_SMUL_HIGHPART UNSPEC_UMUL_HIGHPART])
>
> -(define_int_iterator SVE_COND_INT_OP [UNSPEC_COND_ADD UNSPEC_COND_SUB
> -                                     UNSPEC_COND_SMAX UNSPEC_COND_UMAX
> -                                     UNSPEC_COND_SMIN UNSPEC_COND_UMIN
> -                                     UNSPEC_COND_AND
> -                                     UNSPEC_COND_ORR
> -                                     UNSPEC_COND_EOR])
> +(define_int_iterator SVE_COND_INT2_OP [UNSPEC_COND_ADD UNSPEC_COND_SUB
> +                                      UNSPEC_COND_SMAX UNSPEC_COND_UMAX
> +                                      UNSPEC_COND_SMIN UNSPEC_COND_UMIN
> +                                      UNSPEC_COND_AND
> +                                      UNSPEC_COND_ORR
> +                                      UNSPEC_COND_EOR])
>
> -(define_int_iterator SVE_COND_FP_OP [UNSPEC_COND_ADD UNSPEC_COND_SUB])
> +(define_int_iterator SVE_COND_FP2_OP [UNSPEC_COND_ADD UNSPEC_COND_SUB])
> +
> +(define_int_iterator SVE_COND_FP3_OP [UNSPEC_COND_FMLA UNSPEC_COND_FMLS])
>
>   (define_int_iterator SVE_COND_FP_CMP [UNSPEC_COND_LT UNSPEC_COND_LE
>                                        UNSPEC_COND_EQ UNSPEC_COND_NE
> @@ -1543,7 +1547,9 @@ (define_int_attr optab [(UNSPEC_ANDF "an
>                          (UNSPEC_COND_UMIN "umin")
>                          (UNSPEC_COND_AND "and")
>                          (UNSPEC_COND_ORR "ior")
> -                       (UNSPEC_COND_EOR "xor")])
> +                       (UNSPEC_COND_EOR "xor")
> +                       (UNSPEC_COND_FMLA "fma_rev")
> +                       (UNSPEC_COND_FMLS "fnma_rev")])
>
>   (define_int_attr  maxmin_uns [(UNSPEC_UMAXV "umax")
>                                (UNSPEC_UMINV "umin")
> @@ -1762,4 +1768,6 @@ (define_int_attr sve_int_op [(UNSPEC_CON
>                               (UNSPEC_COND_EOR "eor")])
>
>   (define_int_attr sve_fp_op [(UNSPEC_COND_ADD "fadd")
> -                           (UNSPEC_COND_SUB "fsub")])
> +                           (UNSPEC_COND_SUB "fsub")
> +                           (UNSPEC_COND_FMLA "fmla")
> +                           (UNSPEC_COND_FMLS "fmls")])
> Index: gcc/config/aarch64/aarch64-sve.md
> ===================================================================
> --- gcc/config/aarch64/aarch64-sve.md   2018-05-16 10:23:03.590853492 +0100
> +++ gcc/config/aarch64/aarch64-sve.md   2018-05-16 10:23:03.883838885 +0100
> @@ -1764,7 +1764,7 @@ (define_insn "cond_<optab><mode>"
>            [(match_operand:<VPRED> 1 "register_operand" "Upl")
>             (match_operand:SVE_I 2 "register_operand" "0")
>             (match_operand:SVE_I 3 "register_operand" "w")]
> -         SVE_COND_INT_OP))]
> +         SVE_COND_INT2_OP))]
>     "TARGET_SVE"
>     "<sve_int_op>\t%0.<Vetype>, %1/m, %0.<Vetype>, %3.<Vetype>"
>   )
> @@ -2543,11 +2543,23 @@ (define_insn "cond_<optab><mode>"
>            [(match_operand:<VPRED> 1 "register_operand" "Upl")
>             (match_operand:SVE_F 2 "register_operand" "0")
>             (match_operand:SVE_F 3 "register_operand" "w")]
> -         SVE_COND_FP_OP))]
> +         SVE_COND_FP2_OP))]
>     "TARGET_SVE"
>     "<sve_fp_op>\t%0.<Vetype>, %1/m, %0.<Vetype>, %3.<Vetype>"
>   )
>
> +(define_insn "cond_<optab><mode>"
> +  [(set (match_operand:SVE_F 0 "register_operand" "=w")
> +       (unspec:SVE_F
> +         [(match_operand:<VPRED> 1 "register_operand" "Upl")
> +          (match_operand:SVE_F 2 "register_operand" "0")
> +          (match_operand:SVE_F 3 "register_operand" "w")
> +          (match_operand:SVE_F 4 "register_operand" "w")]
> +         SVE_COND_FP3_OP))]
> +  "TARGET_SVE"
> +  "<sve_fp_op>\t%0.<Vetype>, %1/m, %3.<Vetype>, %4.<Vetype>"
> +)
> +
>   ;; Shift an SVE vector left and insert a scalar into element 0.
>   (define_insn "vec_shl_insert_<mode>"
>     [(set (match_operand:SVE_ALL 0 "register_operand" "=w, w")
> Index: gcc/testsuite/gcc.target/aarch64/sve/reduc_4.c
> ===================================================================
> --- /dev/null   2018-04-20 16:19:46.369131350 +0100
> +++ gcc/testsuite/gcc.target/aarch64/sve/reduc_4.c      2018-05-16 10:23:03.888838636 +0100
> @@ -0,0 +1,18 @@
> +/* { dg-do compile } */
> +/* { dg-options "-O2 -ftree-vectorize -ffast-math" } */
> +
> +double
> +f (double *restrict a, double *restrict b, int *lookup)
> +{
> +  double res = 0.0;
> +  for (int i = 0; i < 512; ++i)
> +    res += a[lookup[i]] * b[i];
> +  return res;
> +}
> +
> +/* { dg-final { scan-assembler-times {\tfmla\tz[0-9]+.d, p[0-7]/m, } 2 } } */
> +/* Check that the vector instructions are the only instructions.  */
> +/* { dg-final { scan-assembler-times {\tfmla\t} 2 } } */
> +/* { dg-final { scan-assembler-not {\tfadd\t} } } */
> +/* { dg-final { scan-assembler-times {\tfaddv\td0,} 1 } } */
> +/* { dg-final { scan-assembler-not {\tsel\t} } } */
> Index: gcc/testsuite/gcc.target/aarch64/sve/reduc_6.c
> ===================================================================
> --- /dev/null   2018-04-20 16:19:46.369131350 +0100
> +++ gcc/testsuite/gcc.target/aarch64/sve/reduc_6.c      2018-05-16 10:23:03.888838636 +0100
> @@ -0,0 +1,17 @@
> +/* { dg-do compile } */
> +/* { dg-options "-O2 -ftree-vectorize -ffast-math" } */
> +
> +#define REDUC(TYPE)                                            \
> +  TYPE reduc_##TYPE (TYPE *x, TYPE *y, int count)              \
> +  {                                                            \
> +    TYPE sum = 0;                                              \
> +    for (int i = 0; i < count; ++i)                            \
> +      sum += x[i] * y[i];                                      \
> +    return sum;                                                        \
> +  }
> +
> +REDUC (float)
> +REDUC (double)
> +
> +/* { dg-final { scan-assembler-times {\tfmla\tz[0-9]+\.s, p[0-7]/m} 1 } } */
> +/* { dg-final { scan-assembler-times {\tfmla\tz[0-9]+\.d, p[0-7]/m} 1 } } */
> Index: gcc/testsuite/gcc.target/aarch64/sve/reduc_7.c
> ===================================================================
> --- /dev/null   2018-04-20 16:19:46.369131350 +0100
> +++ gcc/testsuite/gcc.target/aarch64/sve/reduc_7.c      2018-05-16 10:23:03.889838586 +0100
> @@ -0,0 +1,17 @@
> +/* { dg-do compile } */
> +/* { dg-options "-O2 -ftree-vectorize -ffast-math" } */
> +
> +#define REDUC(TYPE)                                            \
> +  TYPE reduc_##TYPE (TYPE *x, TYPE *y, int count)              \
> +  {                                                            \
> +    TYPE sum = 0;                                              \
> +    for (int i = 0; i < count; ++i)                            \
> +      sum -= x[i] * y[i];                                      \
> +    return sum;                                                        \
> +  }
> +
> +REDUC (float)
> +REDUC (double)
> +
> +/* { dg-final { scan-assembler-times {\tfmls\tz[0-9]+\.s, p[0-7]/m} 1 } } */
> +/* { dg-final { scan-assembler-times {\tfmls\tz[0-9]+\.d, p[0-7]/m} 1 } } */
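As a rough illustration of what the new `cond_<optab><mode>` FP pattern and the reduc tests expect, here is a scalar C model of a fully-masked dot product. It assumes a fixed vector length of 4 lanes (real SVE code is vector-length agnostic), and every name in it is hypothetical. Each vector step models a merging multiply-add (`FMLA Zda, Pg/M, Zn, Zm`): inactive lanes keep their accumulator, which is why operand 2 is tied to the output with the "0" constraint. The multiply-add is written as `acc + x * y`, so unlike a real fused operation it rounds twice.

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical fixed vector length; SVE itself is length-agnostic.  */
#define VL 4

/* Model of a merging-predicated multiply-add on one vector:
   acc[l] = pred[l] ? acc[l] + x[l] * y[l] : acc[l].
   Inactive lanes keep the old accumulator value.  */
static void
masked_fmla (const int pred[VL], double acc[VL],
             const double x[VL], const double y[VL])
{
  for (int l = 0; l < VL; ++l)
    if (pred[l])
      acc[l] = acc[l] + x[l] * y[l];
}

/* Fully-masked dot product: there is no scalar epilogue; the final
   partial vector is handled by a predicate that disables the lanes
   beyond N.  */
static double
dot_product (const double *x, const double *y, size_t n)
{
  assert (n == 0 || (x != NULL && y != NULL));
  double acc[VL] = { 0 };
  for (size_t i = 0; i < n; i += VL)
    {
      int pred[VL];
      double xv[VL] = { 0 }, yv[VL] = { 0 };
      for (int l = 0; l < VL; ++l)
        {
          pred[l] = i + l < n;          /* the loop_mask for this step */
          if (pred[l])
            {
              xv[l] = x[i + l];         /* masked loads */
              yv[l] = y[i + l];
            }
        }
      masked_fmla (pred, acc, xv, yv);
    }
  /* Final horizontal reduction, the role FADDV plays in reduc_4.c.  */
  double sum = 0;
  for (int l = 0; l < VL; ++l)
    sum += acc[l];
  return sum;
}
```

With five elements, the second step runs with only lane 0 active, so lanes 1 to 3 carry their partial sums through unchanged until the final reduction.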
Richard Sandiford May 24, 2018, 12:16 p.m. UTC | #2
Richard Biener <richard.guenther@gmail.com> writes:
> On Wed, May 16, 2018 at 11:26 AM Richard Sandiford <richard.sandiford@linaro.org> wrote:
>
>> This patch adds support for fusing a conditional add or subtract
>> with a multiplication, so that we can use fused multiply-add and
>> multiply-subtract operations for fully-masked reductions.  E.g.
>> for SVE we vectorise:
>
>>    double res = 0.0;
>>    for (int i = 0; i < n; ++i)
>>      res += x[i] * y[i];
>
>> using a fully-masked loop in which the loop body has the form:
>
>>    res_1 = PHI<0(preheader), res_2(latch)>;
>>    avec = IFN_MASK_LOAD (loop_mask, a)
>>    bvec = IFN_MASK_LOAD (loop_mask, b)
>>    prod = avec * bvec;
>>    res_2 = IFN_COND_ADD (loop_mask, res_1, prod);
>
>> where the last statement does the equivalent of:
>
>>    res_2 = loop_mask ? res_1 + prod : res_1;
>
>> (operating elementwise).  The point of the patch is to convert the last
>> two statements into a single internal function that is the equivalent of:
>
>>    res_2 = loop_mask ? fma (avec, bvec, res_1) : res_1;
>
>> (again operating elementwise).
>
>> All current conditional X operations have the form "do X or don't do X
>> to the first operand" (add/don't add to first operand, etc.).  However,
>> the FMA optabs and functions are ordered so that the accumulator comes
>> last.  There were two obvious ways of resolving this: break the
>> convention for conditional operators and have "add/don't add to the
>> final operand" or break the convention for FMA and put the accumulator
>> first.  The patch goes for the latter, but adds _REV to make it obvious
>> that the operands are in a different order.
>
> Eh.  I guess you'll do the same to SAD/DOT_PROD/WIDEN_SUM?
>
> That said, I don't really see the "do or not do to the first operand", it's
> "do or not do the operation on operands 1 to 2 (or 3)".  None of the
> current ops modify operand 1, they all produce a new value, no?

Yeah, neither the current functions nor these ones actually changed
operand 1.  It was all about deciding what the "else" value should be.
The _REV thing was a "fix" for the fact that we wanted the else value
to be the final operand of fma.

Of course, the real fix was to make all the IFN_COND_* functions take an
explicit else value, as you suggested in the review of the other patch
in the series.  So all this _REV stuff is redundant now.
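A minimal scalar sketch of the explicit-else forms, assuming `COND_ADD (c, a, b, e)` means `c ? a + b : e` and `COND_FMA (c, a, b, d, e)` means `c ? a * b + d : e` (the argument order the updated patch uses when it builds IFN_COND_FMA). The helper names are illustrative only, not GCC functions. Because the else value is its own argument, the accumulator can stay last among the FMA operands and also be passed as the else value, so no _REV variant is needed.

```c
#include <assert.h>

/* Hypothetical scalar model: COND_ADD (c, a, b, e) = c ? a + b : e.  */
static double
cond_add (int c, double a, double b, double e)
{
  assert (c == 0 || c == 1);
  return c ? a + b : e;
}

/* Hypothetical scalar model: COND_FMA (c, a, b, d, e) = c ? a * b + d : e.
   Written with two roundings here; a real FMA rounds once.  */
static double
cond_fma (int c, double a, double b, double d, double e)
{
  assert (c == 0 || c == 1);
  return c ? a * b + d : e;
}

/* Before fusion: prod = a * b; res = COND_ADD (m, acc, prod, acc).  */
static double
before_fusion (int m, double a, double b, double acc)
{
  double prod = a * b;
  return cond_add (m, acc, prod, acc);
}

/* After fusion: res = COND_FMA (m, a, b, acc, acc); the accumulator is
   both the addend and the else value.  */
static double
after_fusion (int m, double a, double b, double acc)
{
  return cond_fma (m, a, b, acc, acc);
}
```

With small integer inputs both forms agree exactly; on real hardware the fused form rounds once, so results can differ in the last bit, which is why such fusion is normally only done when floating-point contraction is permitted.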

Here's an updated version based on top of the IFN_COND_FMA patch
that I just posted.  Tested in the same way.

Thanks,
Richard

2018-05-24  Richard Sandiford  <richard.sandiford@linaro.org>
	    Alan Hayward  <alan.hayward@arm.com>
	    David Sherwood  <david.sherwood@arm.com>

gcc/
	* internal-fn.h (can_interpret_as_conditional_op_p): Declare.
	* internal-fn.c (can_interpret_as_conditional_op_p): New function.
	* tree-ssa-math-opts.c (convert_mult_to_fma_1): Handle conditional
	plus and minus and convert them into IFN_COND_FMA-based sequences.
	(convert_mult_to_fma): Handle conditional plus and minus.

gcc/testsuite/
	* gcc.dg/vect/vect-fma-2.c: New test.
	* gcc.target/aarch64/sve/reduc_4.c: Likewise.
	* gcc.target/aarch64/sve/reduc_6.c: Likewise.
	* gcc.target/aarch64/sve/reduc_7.c: Likewise.

Index: gcc/internal-fn.h
===================================================================
--- gcc/internal-fn.h	2018-05-24 13:05:46.049605128 +0100
+++ gcc/internal-fn.h	2018-05-24 13:08:24.643987582 +0100
@@ -196,6 +196,9 @@ extern internal_fn get_conditional_inter
 extern internal_fn get_conditional_internal_fn (internal_fn);
 extern tree_code conditional_internal_fn_code (internal_fn);
 extern internal_fn get_unconditional_internal_fn (internal_fn);
+extern bool can_interpret_as_conditional_op_p (gimple *, tree *,
+					       tree_code *, tree (&)[3],
+					       tree *);
 
 extern bool internal_load_fn_p (internal_fn);
 extern bool internal_store_fn_p (internal_fn);
Index: gcc/internal-fn.c
===================================================================
--- gcc/internal-fn.c	2018-05-24 13:05:46.048606357 +0100
+++ gcc/internal-fn.c	2018-05-24 13:08:24.643987582 +0100
@@ -3333,6 +3333,62 @@ #define CASE(NAME) case IFN_COND_##NAME:
     }
 }
 
+/* Return true if STMT can be interpreted as a conditional tree code
+   operation of the form:
+
+     LHS = COND ? OP (RHS1, ...) : ELSE;
+
+   operating elementwise if the operands are vectors.  This includes
+   the case of an all-true COND, so that the operation always happens.
+
+   When returning true, set:
+
+   - *COND_OUT to the condition COND, or to NULL_TREE if the condition
+     is known to be all-true
+   - *CODE_OUT to the tree code
+   - OPS[I] to operand I of *CODE_OUT
+   - *ELSE_OUT to the fallback value ELSE, or to NULL_TREE if the
+     condition is known to be all true.  */
+
+bool
+can_interpret_as_conditional_op_p (gimple *stmt, tree *cond_out,
+				   tree_code *code_out,
+				   tree (&ops)[3], tree *else_out)
+{
+  if (gassign *assign = dyn_cast <gassign *> (stmt))
+    {
+      *cond_out = NULL_TREE;
+      *code_out = gimple_assign_rhs_code (assign);
+      ops[0] = gimple_assign_rhs1 (assign);
+      ops[1] = gimple_assign_rhs2 (assign);
+      ops[2] = gimple_assign_rhs3 (assign);
+      *else_out = NULL_TREE;
+      return true;
+    }
+  if (gcall *call = dyn_cast <gcall *> (stmt))
+    if (gimple_call_internal_p (call))
+      {
+	internal_fn ifn = gimple_call_internal_fn (call);
+	tree_code code = conditional_internal_fn_code (ifn);
+	if (code != ERROR_MARK)
+	  {
+	    *cond_out = gimple_call_arg (call, 0);
+	    *code_out = code;
+	    unsigned int nops = gimple_call_num_args (call) - 2;
+	    for (unsigned int i = 0; i < 3; ++i)
+	      ops[i] = i < nops ? gimple_call_arg (call, i + 1) : NULL_TREE;
+	    *else_out = gimple_call_arg (call, nops + 1);
+	    if (integer_truep (*cond_out))
+	      {
+		*cond_out = NULL_TREE;
+		*else_out = NULL_TREE;
+	      }
+	    return true;
+	  }
+      }
+  return false;
+}
+
 /* Return true if IFN is some form of load from memory.  */
 
 bool
Index: gcc/tree-ssa-math-opts.c
===================================================================
--- gcc/tree-ssa-math-opts.c	2018-05-18 09:26:37.749713749 +0100
+++ gcc/tree-ssa-math-opts.c	2018-05-24 13:08:24.644961583 +0100
@@ -2655,7 +2655,6 @@ convert_mult_to_fma_1 (tree mul_result,
   FOR_EACH_IMM_USE_STMT (use_stmt, imm_iter, mul_result)
     {
       gimple_stmt_iterator gsi = gsi_for_stmt (use_stmt);
-      enum tree_code use_code;
       tree addop, mulop1 = op1, result = mul_result;
       bool negate_p = false;
       gimple_seq seq = NULL;
@@ -2663,8 +2662,8 @@ convert_mult_to_fma_1 (tree mul_result,
       if (is_gimple_debug (use_stmt))
 	continue;
 
-      use_code = gimple_assign_rhs_code (use_stmt);
-      if (use_code == NEGATE_EXPR)
+      if (is_gimple_assign (use_stmt)
+	  && gimple_assign_rhs_code (use_stmt) == NEGATE_EXPR)
 	{
 	  result = gimple_assign_lhs (use_stmt);
 	  use_operand_p use_p;
@@ -2675,22 +2674,23 @@ convert_mult_to_fma_1 (tree mul_result,
 
 	  use_stmt = neguse_stmt;
 	  gsi = gsi_for_stmt (use_stmt);
-	  use_code = gimple_assign_rhs_code (use_stmt);
 	  negate_p = true;
 	}
 
-      if (gimple_assign_rhs1 (use_stmt) == result)
+      tree cond, else_value, ops[3];
+      tree_code code;
+      if (!can_interpret_as_conditional_op_p (use_stmt, &cond, &code,
+					      ops, &else_value))
+	gcc_unreachable ();
+      addop = ops[0] == result ? ops[1] : ops[0];
+
+      if (code == MINUS_EXPR)
 	{
-	  addop = gimple_assign_rhs2 (use_stmt);
-	  /* a * b - c -> a * b + (-c)  */
-	  if (gimple_assign_rhs_code (use_stmt) == MINUS_EXPR)
+	  if (ops[0] == result)
+	    /* a * b - c -> a * b + (-c)  */
 	    addop = gimple_build (&seq, NEGATE_EXPR, type, addop);
-	}
-      else
-	{
-	  addop = gimple_assign_rhs1 (use_stmt);
-	  /* a - b * c -> (-b) * c + a */
-	  if (gimple_assign_rhs_code (use_stmt) == MINUS_EXPR)
+	  else
+	    /* a - b * c -> (-b) * c + a */
 	    negate_p = !negate_p;
 	}
 
@@ -2699,8 +2699,13 @@ convert_mult_to_fma_1 (tree mul_result,
 
       if (seq)
 	gsi_insert_seq_before (&gsi, seq, GSI_SAME_STMT);
-      fma_stmt = gimple_build_call_internal (IFN_FMA, 3, mulop1, op2, addop);
-      gimple_call_set_lhs (fma_stmt, gimple_assign_lhs (use_stmt));
+
+      if (cond)
+	fma_stmt = gimple_build_call_internal (IFN_COND_FMA, 5, cond, mulop1,
+					       op2, addop, else_value);
+      else
+	fma_stmt = gimple_build_call_internal (IFN_FMA, 3, mulop1, op2, addop);
+      gimple_set_lhs (fma_stmt, gimple_get_lhs (use_stmt));
       gimple_call_set_nothrow (fma_stmt, !stmt_can_throw_internal (use_stmt));
       gsi_replace (&gsi, fma_stmt, true);
       /* Follow all SSA edges so that we generate FMS, FNMA and FNMS
@@ -2883,7 +2888,6 @@ convert_mult_to_fma (gimple *mul_stmt, t
      as an addition.  */
   FOR_EACH_IMM_USE_FAST (use_p, imm_iter, mul_result)
     {
-      enum tree_code use_code;
       tree result = mul_result;
       bool negate_p = false;
 
@@ -2904,13 +2908,9 @@ convert_mult_to_fma (gimple *mul_stmt, t
       if (gimple_bb (use_stmt) != gimple_bb (mul_stmt))
 	return false;
 
-      if (!is_gimple_assign (use_stmt))
-	return false;
-
-      use_code = gimple_assign_rhs_code (use_stmt);
-
       /* A negate on the multiplication leads to FNMA.  */
-      if (use_code == NEGATE_EXPR)
+      if (is_gimple_assign (use_stmt)
+	  && gimple_assign_rhs_code (use_stmt) == NEGATE_EXPR)
 	{
 	  ssa_op_iter iter;
 	  use_operand_p usep;
@@ -2932,17 +2932,20 @@ convert_mult_to_fma (gimple *mul_stmt, t
 	  use_stmt = neguse_stmt;
 	  if (gimple_bb (use_stmt) != gimple_bb (mul_stmt))
 	    return false;
-	  if (!is_gimple_assign (use_stmt))
-	    return false;
 
-	  use_code = gimple_assign_rhs_code (use_stmt);
 	  negate_p = true;
 	}
 
-      switch (use_code)
+      tree cond, else_value, ops[3];
+      tree_code code;
+      if (!can_interpret_as_conditional_op_p (use_stmt, &cond, &code, ops,
+					      &else_value))
+	return false;
+
+      switch (code)
 	{
 	case MINUS_EXPR:
-	  if (gimple_assign_rhs2 (use_stmt) == result)
+	  if (ops[1] == result)
 	    negate_p = !negate_p;
 	  break;
 	case PLUS_EXPR:
@@ -2952,47 +2955,50 @@ convert_mult_to_fma (gimple *mul_stmt, t
 	  return false;
 	}
 
-      /* If the subtrahend (gimple_assign_rhs2 (use_stmt)) is computed
-	 by a MULT_EXPR that we'll visit later, we might be able to
-	 get a more profitable match with fnma.
+      if (cond)
+	{
+	  if (cond == result || else_value == result)
+	    return false;
+	  if (!direct_internal_fn_supported_p (IFN_COND_FMA, type, opt_type))
+	    return false;
+	}
+
+      /* If the subtrahend (OPS[1]) is computed by a MULT_EXPR that
+	 we'll visit later, we might be able to get a more profitable
+	 match with fnma.
 	 OTOH, if we don't, a negate / fma pair has likely lower latency
 	 that a mult / subtract pair.  */
-      if (use_code == MINUS_EXPR && !negate_p
-	  && gimple_assign_rhs1 (use_stmt) == result
+      if (code == MINUS_EXPR
+	  && !negate_p
+	  && ops[0] == result
 	  && !direct_internal_fn_supported_p (IFN_FMS, type, opt_type)
-	  && direct_internal_fn_supported_p (IFN_FNMA, type, opt_type))
+	  && direct_internal_fn_supported_p (IFN_FNMA, type, opt_type)
+	  && TREE_CODE (ops[1]) == SSA_NAME
+	  && has_single_use (ops[1]))
 	{
-	  tree rhs2 = gimple_assign_rhs2 (use_stmt);
-
-	  if (TREE_CODE (rhs2) == SSA_NAME)
-	    {
-	      gimple *stmt2 = SSA_NAME_DEF_STMT (rhs2);
-	      if (has_single_use (rhs2)
-		  && is_gimple_assign (stmt2)
-		  && gimple_assign_rhs_code (stmt2) == MULT_EXPR)
-	      return false;
-	    }
+	  gimple *stmt2 = SSA_NAME_DEF_STMT (ops[1]);
+	  if (is_gimple_assign (stmt2)
+	      && gimple_assign_rhs_code (stmt2) == MULT_EXPR)
+	    return false;
 	}
 
-      tree use_rhs1 = gimple_assign_rhs1 (use_stmt);
-      tree use_rhs2 = gimple_assign_rhs2 (use_stmt);
       /* We can't handle a * b + a * b.  */
-      if (use_rhs1 == use_rhs2)
+      if (ops[0] == ops[1])
 	return false;
       /* If deferring, make sure we are not looking at an instruction that
 	 wouldn't have existed if we were not.  */
       if (state->m_deferring_p
-	  && (state->m_mul_result_set.contains (use_rhs1)
-	      || state->m_mul_result_set.contains (use_rhs2)))
+	  && (state->m_mul_result_set.contains (ops[0])
+	      || state->m_mul_result_set.contains (ops[1])))
 	return false;
 
       if (check_defer)
 	{
-	  tree use_lhs = gimple_assign_lhs (use_stmt);
+	  tree use_lhs = gimple_get_lhs (use_stmt);
 	  if (state->m_last_result)
 	    {
-	      if (use_rhs2 == state->m_last_result
-		  || use_rhs1 == state->m_last_result)
+	      if (ops[1] == state->m_last_result
+		  || ops[0] == state->m_last_result)
 		defer = true;
 	      else
 		defer = false;
@@ -3001,12 +3007,12 @@ convert_mult_to_fma (gimple *mul_stmt, t
 	    {
 	      gcc_checking_assert (!state->m_initial_phi);
 	      gphi *phi;
-	      if (use_rhs1 == result)
-		phi = result_of_phi (use_rhs2);
+	      if (ops[0] == result)
+		phi = result_of_phi (ops[1]);
 	      else
 		{
-		  gcc_assert (use_rhs2 == result);
-		  phi = result_of_phi (use_rhs1);
+		  gcc_assert (ops[1] == result);
+		  phi = result_of_phi (ops[0]);
 		}
 
 	      if (phi)
Index: gcc/testsuite/gcc.dg/vect/vect-fma-2.c
===================================================================
--- /dev/null	2018-04-20 16:19:46.369131350 +0100
+++ gcc/testsuite/gcc.dg/vect/vect-fma-2.c	2018-05-24 13:08:24.643987582 +0100
@@ -0,0 +1,17 @@
+/* { dg-do compile } */
+/* { dg-additional-options "-fdump-tree-optimized -fassociative-math -fno-trapping-math -fno-signed-zeros" } */
+
+#include "tree-vect.h"
+
+#define N (VECTOR_BITS * 11 / 64 + 3)
+
+double
+dot_prod (double *x, double *y)
+{
+  double sum = 0;
+  for (int i = 0; i < N; ++i)
+    sum += x[i] * y[i];
+  return sum;
+}
+
+/* { dg-final { scan-tree-dump { = \.COND_FMA } "optimized" { target { vect_double && { vect_fully_masked && scalar_all_fma } } } } } */
Index: gcc/testsuite/gcc.target/aarch64/sve/reduc_4.c
===================================================================
--- /dev/null	2018-04-20 16:19:46.369131350 +0100
+++ gcc/testsuite/gcc.target/aarch64/sve/reduc_4.c	2018-05-24 13:08:24.643987582 +0100
@@ -0,0 +1,18 @@
+/* { dg-do compile } */
+/* { dg-options "-O2 -ftree-vectorize -ffast-math" } */
+
+double
+f (double *restrict a, double *restrict b, int *lookup)
+{
+  double res = 0.0;
+  for (int i = 0; i < 512; ++i)
+    res += a[lookup[i]] * b[i];
+  return res;
+}
+
+/* { dg-final { scan-assembler-times {\tfmla\tz[0-9]+.d, p[0-7]/m, } 2 } } */
+/* Check that the vector instructions are the only instructions.  */
+/* { dg-final { scan-assembler-times {\tfmla\t} 2 } } */
+/* { dg-final { scan-assembler-not {\tfadd\t} } } */
+/* { dg-final { scan-assembler-times {\tfaddv\td0,} 1 } } */
+/* { dg-final { scan-assembler-not {\tsel\t} } } */
Index: gcc/testsuite/gcc.target/aarch64/sve/reduc_6.c
===================================================================
--- /dev/null	2018-04-20 16:19:46.369131350 +0100
+++ gcc/testsuite/gcc.target/aarch64/sve/reduc_6.c	2018-05-24 13:08:24.643987582 +0100
@@ -0,0 +1,17 @@
+/* { dg-do compile } */
+/* { dg-options "-O2 -ftree-vectorize -ffast-math" } */
+
+#define REDUC(TYPE)						\
+  TYPE reduc_##TYPE (TYPE *x, TYPE *y, int count)		\
+  {								\
+    TYPE sum = 0;						\
+    for (int i = 0; i < count; ++i)				\
+      sum += x[i] * y[i];					\
+    return sum;							\
+  }
+
+REDUC (float)
+REDUC (double)
+
+/* { dg-final { scan-assembler-times {\tfmla\tz[0-9]+\.s, p[0-7]/m} 1 } } */
+/* { dg-final { scan-assembler-times {\tfmla\tz[0-9]+\.d, p[0-7]/m} 1 } } */
Index: gcc/testsuite/gcc.target/aarch64/sve/reduc_7.c
===================================================================
--- /dev/null	2018-04-20 16:19:46.369131350 +0100
+++ gcc/testsuite/gcc.target/aarch64/sve/reduc_7.c	2018-05-24 13:08:24.643987582 +0100
@@ -0,0 +1,17 @@
+/* { dg-do compile } */
+/* { dg-options "-O2 -ftree-vectorize -ffast-math" } */
+
+#define REDUC(TYPE)						\
+  TYPE reduc_##TYPE (TYPE *x, TYPE *y, int count)		\
+  {								\
+    TYPE sum = 0;						\
+    for (int i = 0; i < count; ++i)				\
+      sum -= x[i] * y[i];					\
+    return sum;							\
+  }
+
+REDUC (float)
+REDUC (double)
+
+/* { dg-final { scan-assembler-times {\tfmls\tz[0-9]+\.s, p[0-7]/m} 1 } } */
+/* { dg-final { scan-assembler-times {\tfmls\tz[0-9]+\.d, p[0-7]/m} 1 } } */
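The operand adjustments in `convert_mult_to_fma_1` can be checked with a small scalar model. It is a hypothetical sketch, not GCC code: `fuse`, `model_fma` and `enum op_code` are invented here, the multiply-add rounds twice instead of once, and a condition of -1 stands for the unconditional (`NULL_TREE`) case that produces IFN_FMA rather than IFN_COND_FMA.

```c
#include <assert.h>
#include <stdbool.h>

enum op_code { OP_PLUS, OP_MINUS };

/* a * b + c; rounds twice, unlike a real fused multiply-add.  */
static double
model_fma (double a, double b, double c)
{
  return a * b + c;
}

/* Fuse "PROD CODE OTHER" (PROD_FIRST) or "OTHER CODE PROD" (otherwise),
   where PROD = A * B.  COND is -1 for an unconditional statement, else
   the lane's predicate bit (0 or 1), with ELSE_VALUE used when it is 0.  */
static double
fuse (enum op_code code, bool prod_first, double a, double b,
      double other, int cond, double else_value)
{
  assert (cond >= -1 && cond <= 1);
  double mulop1 = a, addop = other;
  if (code == OP_MINUS)
    {
      if (prod_first)
        addop = -addop;        /* a * b - c -> a * b + (-c) */
      else
        mulop1 = -mulop1;      /* c - a * b -> (-a) * b + c */
    }
  double fused = model_fma (mulop1, b, addop);
  if (cond < 0)
    return fused;                     /* IFN_FMA (mulop1, b, addop) */
  return cond ? fused : else_value;   /* IFN_COND_FMA (cond, ..., else) */
}
```

For example, `fuse (OP_MINUS, false, 2.0, 3.0, 10.0, -1, 0.0)` models `10 - 2 * 3` and yields 4, while a zero predicate bit returns the else value untouched.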
Richard Biener May 25, 2018, 7:22 a.m. UTC | #3
On Thu, May 24, 2018 at 2:17 PM Richard Sandiford <richard.sandiford@linaro.org> wrote:

> Richard Biener <richard.guenther@gmail.com> writes:
> > On Wed, May 16, 2018 at 11:26 AM Richard Sandiford <richard.sandiford@linaro.org> wrote:
> >
> >> This patch adds support for fusing a conditional add or subtract
> >> with a multiplication, so that we can use fused multiply-add and
> >> multiply-subtract operations for fully-masked reductions.  E.g.
> >> for SVE we vectorise:
> >
> >>    double res = 0.0;
> >>    for (int i = 0; i < n; ++i)
> >>      res += x[i] * y[i];
> >
> >> using a fully-masked loop in which the loop body has the form:
> >
> >>    res_1 = PHI<0(preheader), res_2(latch)>;
> >>    avec = IFN_MASK_LOAD (loop_mask, a)
> >>    bvec = IFN_MASK_LOAD (loop_mask, b)
> >>    prod = avec * bvec;
> >>    res_2 = IFN_COND_ADD (loop_mask, res_1, prod);
> >
> >> where the last statement does the equivalent of:
> >
> >>    res_2 = loop_mask ? res_1 + prod : res_1;
> >
> >> (operating elementwise).  The point of the patch is to convert the last
> >> two statements into a single internal function that is the equivalent of:
> >
> >>    res_2 = loop_mask ? fma (avec, bvec, res_1) : res_1;
> >
> >> (again operating elementwise).
> >
> >> All current conditional X operations have the form "do X or don't do X
> >> to the first operand" (add/don't add to first operand, etc.).  However,
> >> the FMA optabs and functions are ordered so that the accumulator comes
> >> last.  There were two obvious ways of resolving this: break the
> >> convention for conditional operators and have "add/don't add to the
> >> final operand" or break the convention for FMA and put the accumulator
> >> first.  The patch goes for the latter, but adds _REV to make it obvious
> >> that the operands are in a different order.
> >
> > Eh.  I guess you'll do the same to SAD/DOT_PROD/WIDEN_SUM?
> >
> > That said, I don't really see the "do or not do to the first operand", it's
> > "do or not do the operation on operands 1 to 2 (or 3)".  None of the
> > current ops modify operand 1, they all produce a new value, no?
>
> Yeah, neither the current functions nor these ones actually changed
> operand 1.  It was all about deciding what the "else" value should be.
> The _REV thing was a "fix" for the fact that we wanted the else value
> to be the final operand of fma.
>
> Of course, the real fix was to make all the IFN_COND_* functions take an
> explicit else value, as you suggested in the review of the other patch
> in the series.  So all this _REV stuff is redundant now.
>
> Here's an updated version based on top of the IFN_COND_FMA patch
> that I just posted.  Tested in the same way.

OK.

Thanks,
Richard.

> Thanks,

> Richard


> 2018-05-24  Richard Sandiford  <richard.sandiford@linaro.org>

>              Alan Hayward  <alan.hayward@arm.com>

>              David Sherwood  <david.sherwood@arm.com>


> gcc/

>          * internal-fn.h (can_interpret_as_conditional_op_p): Declare.

>          * internal-fn.c (can_interpret_as_conditional_op_p): New function.

>          * tree-ssa-math-opts.c (convert_mult_to_fma_1): Handle conditional

>          plus and minus and convert them into IFN_COND_FMA-based sequences.

>          (convert_mult_to_fma): Handle conditional plus and minus.


> gcc/testsuite/

>          * gcc.dg/vect/vect-fma-2.c: New test.

>          * gcc.target/aarch64/sve/reduc_4.c: Likewise.

>          * gcc.target/aarch64/sve/reduc_6.c: Likewise.

>          * gcc.target/aarch64/sve/reduc_7.c: Likewise.


> Index: gcc/internal-fn.h

> ===================================================================

> --- gcc/internal-fn.h   2018-05-24 13:05:46.049605128 +0100

> +++ gcc/internal-fn.h   2018-05-24 13:08:24.643987582 +0100

> @@ -196,6 +196,9 @@ extern internal_fn get_conditional_inter

>   extern internal_fn get_conditional_internal_fn (internal_fn);

>   extern tree_code conditional_internal_fn_code (internal_fn);

>   extern internal_fn get_unconditional_internal_fn (internal_fn);

> +extern bool can_interpret_as_conditional_op_p (gimple *, tree *,

> +                                              tree_code *, tree (&)[3],

> +                                              tree *);


>   extern bool internal_load_fn_p (internal_fn);

>   extern bool internal_store_fn_p (internal_fn);

> Index: gcc/internal-fn.c

> ===================================================================

> --- gcc/internal-fn.c   2018-05-24 13:05:46.048606357 +0100

> +++ gcc/internal-fn.c   2018-05-24 13:08:24.643987582 +0100

> @@ -3333,6 +3333,62 @@ #define CASE(NAME) case IFN_COND_##NAME:

>       }

>   }


> +/* Return true if STMT can be interpreted as a conditional tree code

> +   operation of the form:

> +

> +     LHS = COND ? OP (RHS1, ...) : ELSE;

> +

> +   operating elementwise if the operands are vectors.  This includes

> +   the case of an all-true COND, so that the operation always happens.

> +

> +   When returning true, set:

> +

> +   - *COND_OUT to the condition COND, or to NULL_TREE if the condition

> +     is known to be all-true

> +   - *CODE_OUT to the tree code

> +   - OPS[I] to operand I of *CODE_OUT

> +   - *ELSE_OUT to the fallback value ELSE, or to NULL_TREE if the

> +     condition is known to be all true.  */

> +

> +bool

> +can_interpret_as_conditional_op_p (gimple *stmt, tree *cond_out,

> +                                  tree_code *code_out,

> +                                  tree (&ops)[3], tree *else_out)

> +{

> +  if (gassign *assign = dyn_cast <gassign *> (stmt))

> +    {

> +      *cond_out = NULL_TREE;

> +      *code_out = gimple_assign_rhs_code (assign);

> +      ops[0] = gimple_assign_rhs1 (assign);

> +      ops[1] = gimple_assign_rhs2 (assign);

> +      ops[2] = gimple_assign_rhs3 (assign);

> +      *else_out = NULL_TREE;

> +      return true;

> +    }

> +  if (gcall *call = dyn_cast <gcall *> (stmt))

> +    if (gimple_call_internal_p (call))

> +      {

> +       internal_fn ifn = gimple_call_internal_fn (call);

> +       tree_code code = conditional_internal_fn_code (ifn);

> +       if (code != ERROR_MARK)

> +         {

> +           *cond_out = gimple_call_arg (call, 0);

> +           *code_out = code;

> +           unsigned int nops = gimple_call_num_args (call) - 2;

> +           for (unsigned int i = 0; i < 3; ++i)

> +             ops[i] = i < nops ? gimple_call_arg (call, i + 1) :

NULL_TREE;
> +           *else_out = gimple_call_arg (call, nops + 1);

> +           if (integer_truep (*cond_out))

> +             {

> +               *cond_out = NULL_TREE;

> +               *else_out = NULL_TREE;

> +             }

> +           return true;

> +         }

> +      }

> +  return false;

> +}

> +

>   /* Return true if IFN is some form of load from memory.  */


>   bool

> Index: gcc/tree-ssa-math-opts.c

> ===================================================================

> --- gcc/tree-ssa-math-opts.c    2018-05-18 09:26:37.749713749 +0100

> +++ gcc/tree-ssa-math-opts.c    2018-05-24 13:08:24.644961583 +0100

> @@ -2655,7 +2655,6 @@ convert_mult_to_fma_1 (tree mul_result,

>     FOR_EACH_IMM_USE_STMT (use_stmt, imm_iter, mul_result)

>       {

>         gimple_stmt_iterator gsi = gsi_for_stmt (use_stmt);

> -      enum tree_code use_code;

>         tree addop, mulop1 = op1, result = mul_result;

>         bool negate_p = false;

>         gimple_seq seq = NULL;

> @@ -2663,8 +2662,8 @@ convert_mult_to_fma_1 (tree mul_result,

>         if (is_gimple_debug (use_stmt))

>          continue;


> -      use_code = gimple_assign_rhs_code (use_stmt);

> -      if (use_code == NEGATE_EXPR)

> +      if (is_gimple_assign (use_stmt)

> +         && gimple_assign_rhs_code (use_stmt) == NEGATE_EXPR)

>          {

>            result = gimple_assign_lhs (use_stmt);

>            use_operand_p use_p;

> @@ -2675,22 +2674,23 @@ convert_mult_to_fma_1 (tree mul_result,


>            use_stmt = neguse_stmt;

>            gsi = gsi_for_stmt (use_stmt);

> -         use_code = gimple_assign_rhs_code (use_stmt);

>            negate_p = true;

>          }


> -      if (gimple_assign_rhs1 (use_stmt) == result)

> +      tree cond, else_value, ops[3];

> +      tree_code code;

> +      if (!can_interpret_as_conditional_op_p (use_stmt, &cond, &code,

> +                                             ops, &else_value))

> +       gcc_unreachable ();

> +      addop = ops[0] == result ? ops[1] : ops[0];

> +

> +      if (code == MINUS_EXPR)

>          {

> -         addop = gimple_assign_rhs2 (use_stmt);

> -         /* a * b - c -> a * b + (-c)  */

> -         if (gimple_assign_rhs_code (use_stmt) == MINUS_EXPR)

> +         if (ops[0] == result)

> +           /* a * b - c -> a * b + (-c)  */

>              addop = gimple_build (&seq, NEGATE_EXPR, type, addop);

> -       }

> -      else

> -       {

> -         addop = gimple_assign_rhs1 (use_stmt);

> -         /* a - b * c -> (-b) * c + a */

> -         if (gimple_assign_rhs_code (use_stmt) == MINUS_EXPR)

> +         else

> +           /* a - b * c -> (-b) * c + a */

>              negate_p = !negate_p;

>          }


> @@ -2699,8 +2699,13 @@ convert_mult_to_fma_1 (tree mul_result,


>         if (seq)

>          gsi_insert_seq_before (&gsi, seq, GSI_SAME_STMT);

> -      fma_stmt = gimple_build_call_internal (IFN_FMA, 3, mulop1, op2,

addop);
> -      gimple_call_set_lhs (fma_stmt, gimple_assign_lhs (use_stmt));

> +

> +      if (cond)

> +       fma_stmt = gimple_build_call_internal (IFN_COND_FMA, 5, cond,

mulop1,
> +                                              op2, addop, else_value);

> +      else

> +       fma_stmt = gimple_build_call_internal (IFN_FMA, 3, mulop1, op2,

addop);
> +      gimple_set_lhs (fma_stmt, gimple_get_lhs (use_stmt));

>         gimple_call_set_nothrow (fma_stmt, !stmt_can_throw_internal

(use_stmt));
>         gsi_replace (&gsi, fma_stmt, true);

>         /* Follow all SSA edges so that we generate FMS, FNMA and FNMS

> @@ -2883,7 +2888,6 @@ convert_mult_to_fma (gimple *mul_stmt, t

>        as an addition.  */

>     FOR_EACH_IMM_USE_FAST (use_p, imm_iter, mul_result)

>       {

> -      enum tree_code use_code;

>         tree result = mul_result;

>         bool negate_p = false;


> @@ -2904,13 +2908,9 @@ convert_mult_to_fma (gimple *mul_stmt, t

>         if (gimple_bb (use_stmt) != gimple_bb (mul_stmt))

>          return false;


> -      if (!is_gimple_assign (use_stmt))

> -       return false;

> -

> -      use_code = gimple_assign_rhs_code (use_stmt);

> -

>         /* A negate on the multiplication leads to FNMA.  */

> -      if (use_code == NEGATE_EXPR)

> +      if (is_gimple_assign (use_stmt)

> +         && gimple_assign_rhs_code (use_stmt) == NEGATE_EXPR)

>          {

>            ssa_op_iter iter;

>            use_operand_p usep;

> @@ -2932,17 +2932,20 @@ convert_mult_to_fma (gimple *mul_stmt, t

>            use_stmt = neguse_stmt;

>            if (gimple_bb (use_stmt) != gimple_bb (mul_stmt))

>              return false;

> -         if (!is_gimple_assign (use_stmt))

> -           return false;


> -         use_code = gimple_assign_rhs_code (use_stmt);

>            negate_p = true;

>          }


> -      switch (use_code)

> +      tree cond, else_value, ops[3];

> +      tree_code code;

> +      if (!can_interpret_as_conditional_op_p (use_stmt, &cond, &code,

ops,
> +                                             &else_value))

> +       return false;

> +

> +      switch (code)

>          {

>          case MINUS_EXPR:

> -         if (gimple_assign_rhs2 (use_stmt) == result)

> +         if (ops[1] == result)

>              negate_p = !negate_p;

>            break;

>          case PLUS_EXPR:

> @@ -2952,47 +2955,50 @@ convert_mult_to_fma (gimple *mul_stmt, t
>            return false;
>          }
>
> -      /* If the subtrahend (gimple_assign_rhs2 (use_stmt)) is computed
> -        by a MULT_EXPR that we'll visit later, we might be able to
> -        get a more profitable match with fnma.
> +      if (cond)
> +       {
> +         if (cond == result || else_value == result)
> +           return false;
> +         if (!direct_internal_fn_supported_p (IFN_COND_FMA, type, opt_type))
> +           return false;
> +       }
> +
> +      /* If the subtrahend (OPS[1]) is computed by a MULT_EXPR that
> +        we'll visit later, we might be able to get a more profitable
> +        match with fnma.
>           OTOH, if we don't, a negate / fma pair has likely lower latency
>           that a mult / subtract pair.  */
> -      if (use_code == MINUS_EXPR && !negate_p
> -         && gimple_assign_rhs1 (use_stmt) == result
> +      if (code == MINUS_EXPR
> +         && !negate_p
> +         && ops[0] == result
>            && !direct_internal_fn_supported_p (IFN_FMS, type, opt_type)
> -         && direct_internal_fn_supported_p (IFN_FNMA, type, opt_type))
> +         && direct_internal_fn_supported_p (IFN_FNMA, type, opt_type)
> +         && TREE_CODE (ops[1]) == SSA_NAME
> +         && has_single_use (ops[1]))
>          {
> -         tree rhs2 = gimple_assign_rhs2 (use_stmt);
> -
> -         if (TREE_CODE (rhs2) == SSA_NAME)
> -           {
> -             gimple *stmt2 = SSA_NAME_DEF_STMT (rhs2);
> -             if (has_single_use (rhs2)
> -                 && is_gimple_assign (stmt2)
> -                 && gimple_assign_rhs_code (stmt2) == MULT_EXPR)
> -             return false;
> -           }
> +         gimple *stmt2 = SSA_NAME_DEF_STMT (ops[1]);
> +         if (is_gimple_assign (stmt2)
> +             && gimple_assign_rhs_code (stmt2) == MULT_EXPR)
> +           return false;
>          }
>
> -      tree use_rhs1 = gimple_assign_rhs1 (use_stmt);
> -      tree use_rhs2 = gimple_assign_rhs2 (use_stmt);
>         /* We can't handle a * b + a * b.  */
> -      if (use_rhs1 == use_rhs2)
> +      if (ops[0] == ops[1])
>          return false;
>         /* If deferring, make sure we are not looking at an instruction that
>           wouldn't have existed if we were not.  */
>         if (state->m_deferring_p
> -         && (state->m_mul_result_set.contains (use_rhs1)
> -             || state->m_mul_result_set.contains (use_rhs2)))
> +         && (state->m_mul_result_set.contains (ops[0])
> +             || state->m_mul_result_set.contains (ops[1])))
>          return false;
>
>         if (check_defer)
>          {
> -         tree use_lhs = gimple_assign_lhs (use_stmt);
> +         tree use_lhs = gimple_get_lhs (use_stmt);
>            if (state->m_last_result)
>              {
> -             if (use_rhs2 == state->m_last_result
> -                 || use_rhs1 == state->m_last_result)
> +             if (ops[1] == state->m_last_result
> +                 || ops[0] == state->m_last_result)
>                  defer = true;
>                else
>                  defer = false;
> @@ -3001,12 +3007,12 @@ convert_mult_to_fma (gimple *mul_stmt, t
>              {
>                gcc_checking_assert (!state->m_initial_phi);
>                gphi *phi;
> -             if (use_rhs1 == result)
> -               phi = result_of_phi (use_rhs2);
> +             if (ops[0] == result)
> +               phi = result_of_phi (ops[1]);
>                else
>                  {
> -                 gcc_assert (use_rhs2 == result);
> -                 phi = result_of_phi (use_rhs1);
> +                 gcc_assert (ops[1] == result);
> +                 phi = result_of_phi (ops[0]);
>                  }
>
>                if (phi)
> Index: gcc/testsuite/gcc.dg/vect/vect-fma-2.c
> ===================================================================
> --- /dev/null   2018-04-20 16:19:46.369131350 +0100
> +++ gcc/testsuite/gcc.dg/vect/vect-fma-2.c      2018-05-24 13:08:24.643987582 +0100
> @@ -0,0 +1,17 @@
> +/* { dg-do compile } */
> +/* { dg-additional-options "-fdump-tree-optimized -fassociative-math -fno-trapping-math -fno-signed-zeros" } */
> +
> +#include "tree-vect.h"
> +
> +#define N (VECTOR_BITS * 11 / 64 + 3)
> +
> +double
> +dot_prod (double *x, double *y)
> +{
> +  double sum = 0;
> +  for (int i = 0; i < N; ++i)
> +    sum += x[i] * y[i];
> +  return sum;
> +}
> +
> +/* { dg-final { scan-tree-dump { = \.COND_FMA } "optimized" { target { vect_double && { vect_fully_masked && scalar_all_fma } } } } } */
> Index: gcc/testsuite/gcc.target/aarch64/sve/reduc_4.c
> ===================================================================
> --- /dev/null   2018-04-20 16:19:46.369131350 +0100
> +++ gcc/testsuite/gcc.target/aarch64/sve/reduc_4.c      2018-05-24 13:08:24.643987582 +0100
> @@ -0,0 +1,18 @@
> +/* { dg-do compile } */
> +/* { dg-options "-O2 -ftree-vectorize -ffast-math" } */
> +
> +double
> +f (double *restrict a, double *restrict b, int *lookup)
> +{
> +  double res = 0.0;
> +  for (int i = 0; i < 512; ++i)
> +    res += a[lookup[i]] * b[i];
> +  return res;
> +}
> +
> +/* { dg-final { scan-assembler-times {\tfmla\tz[0-9]+.d, p[0-7]/m, } 2 } } */
> +/* Check that the vector instructions are the only instructions.  */
> +/* { dg-final { scan-assembler-times {\tfmla\t} 2 } } */
> +/* { dg-final { scan-assembler-not {\tfadd\t} } } */
> +/* { dg-final { scan-assembler-times {\tfaddv\td0,} 1 } } */
> +/* { dg-final { scan-assembler-not {\tsel\t} } } */
> Index: gcc/testsuite/gcc.target/aarch64/sve/reduc_6.c
> ===================================================================
> --- /dev/null   2018-04-20 16:19:46.369131350 +0100
> +++ gcc/testsuite/gcc.target/aarch64/sve/reduc_6.c      2018-05-24 13:08:24.643987582 +0100
> @@ -0,0 +1,17 @@
> +/* { dg-do compile } */
> +/* { dg-options "-O2 -ftree-vectorize -ffast-math" } */
> +
> +#define REDUC(TYPE)                                            \
> +  TYPE reduc_##TYPE (TYPE *x, TYPE *y, int count)              \
> +  {                                                            \
> +    TYPE sum = 0;                                              \
> +    for (int i = 0; i < count; ++i)                            \
> +      sum += x[i] * y[i];                                      \
> +    return sum;                                                        \
> +  }
> +
> +REDUC (float)
> +REDUC (double)
> +
> +/* { dg-final { scan-assembler-times {\tfmla\tz[0-9]+\.s, p[0-7]/m} 1 } } */
> +/* { dg-final { scan-assembler-times {\tfmla\tz[0-9]+\.d, p[0-7]/m} 1 } } */
> Index: gcc/testsuite/gcc.target/aarch64/sve/reduc_7.c
> ===================================================================
> --- /dev/null   2018-04-20 16:19:46.369131350 +0100
> +++ gcc/testsuite/gcc.target/aarch64/sve/reduc_7.c      2018-05-24 13:08:24.643987582 +0100
> @@ -0,0 +1,17 @@
> +/* { dg-do compile } */
> +/* { dg-options "-O2 -ftree-vectorize -ffast-math" } */
> +
> +#define REDUC(TYPE)                                            \
> +  TYPE reduc_##TYPE (TYPE *x, TYPE *y, int count)              \
> +  {                                                            \
> +    TYPE sum = 0;                                              \
> +    for (int i = 0; i < count; ++i)                            \
> +      sum -= x[i] * y[i];                                      \
> +    return sum;                                                        \
> +  }
> +
> +REDUC (float)
> +REDUC (double)
> +
> +/* { dg-final { scan-assembler-times {\tfmls\tz[0-9]+\.s, p[0-7]/m} 1 } } */
> +/* { dg-final { scan-assembler-times {\tfmls\tz[0-9]+\.d, p[0-7]/m} 1 } } */

Patch

Index: gcc/doc/md.texi
===================================================================
--- gcc/doc/md.texi	2018-05-16 10:23:03.590853492 +0100
+++ gcc/doc/md.texi	2018-05-16 10:23:03.886838736 +0100
@@ -6367,6 +6367,32 @@  be in a normal C @samp{?:} condition.
 Operands 0, 2 and 3 all have mode @var{m}, while operand 1 has the mode
 returned by @code{TARGET_VECTORIZE_GET_MASK_MODE}.
 
+@cindex @code{cond_fma_rev@var{mode}} instruction pattern
+@item @samp{cond_fma_rev@var{mode}}
+Similar to @samp{cond_add@var{m}}, but compute:
+@smallexample
+op0 = op1 ? fma (op3, op4, op2) : op2;
+@end smallexample
+for scalars and:
+@smallexample
+op0[I] = op1[I] ? fma (op3[I], op4[I], op2[I]) : op2[I];
+@end smallexample
+for vectors.  The @samp{_rev} indicates that the addend (operand 2)
+comes first.
+
+@cindex @code{cond_fnma_rev@var{mode}} instruction pattern
+@item @samp{cond_fnma_rev@var{mode}}
+Similar to @samp{cond_fma_rev@var{m}}, but negate operand 3 before
+multiplying it.  That is, the instruction performs:
+@smallexample
+op0 = op1 ? fma (-op3, op4, op2) : op2;
+@end smallexample
+for scalars and:
+@smallexample
+op0[I] = op1[I] ? fma (-op3[I], op4[I], op2[I]) : op2[I];
+@end smallexample
+for vectors.
+
 @cindex @code{neg@var{mode}cc} instruction pattern
 @item @samp{neg@var{mode}cc}
 Similar to @samp{mov@var{mode}cc} but for conditional negation.  Conditionally
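As a reading aid for the two patterns documented above, here is a hypothetical scalar reference model in C. The function names ref_cond_fma_rev, ref_cond_fnma_rev and mul_add are invented for illustration and are not part of GCC. The mask is a bool array. To keep the sketch free of a libm dependency, a * b + c stands in for a true single-rounding fma, so the two can differ in the last bit for inexact inputs. The argument order follows the documentation: condition, then addend, then the two multiplicands.

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>

/* Stand-in for C99 fma (a, b, c): two roundings instead of one.  */
static double
mul_add (double a, double b, double c)
{
  return a * b + c;
}

/* Hypothetical model of cond_fma_rev<mode>:
     op0[i] = op1[i] ? fma (op3[i], op4[i], op2[i]) : op2[i]
   The addend (op2) precedes the multiplicands, hence "_rev".  */
static void
ref_cond_fma_rev (double *op0, const bool *op1, const double *op2,
                  const double *op3, const double *op4, size_t n)
{
  assert (n == 0 || (op0 && op1 && op2 && op3 && op4));
  for (size_t i = 0; i < n; ++i)
    op0[i] = op1[i] ? mul_add (op3[i], op4[i], op2[i]) : op2[i];
}

/* Hypothetical model of cond_fnma_rev<mode>: as above, but operand 3
   is negated before the multiplication.  */
static void
ref_cond_fnma_rev (double *op0, const bool *op1, const double *op2,
                   const double *op3, const double *op4, size_t n)
{
  assert (n == 0 || (op0 && op1 && op2 && op3 && op4));
  for (size_t i = 0; i < n; ++i)
    op0[i] = op1[i] ? mul_add (-op3[i], op4[i], op2[i]) : op2[i];
}
```

Inactive lanes pass the addend through unchanged. This is what allows a fully-masked reduction to leave its accumulator untouched in lanes beyond the end of the loop.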
Index: gcc/optabs.def
===================================================================
--- gcc/optabs.def	2018-05-16 10:23:03.590853492 +0100
+++ gcc/optabs.def	2018-05-16 10:23:03.887838686 +0100
@@ -222,6 +222,8 @@  OPTAB_D (notcc_optab, "not$acc")
 OPTAB_D (movcc_optab, "mov$acc")
 OPTAB_D (cond_add_optab, "cond_add$a")
 OPTAB_D (cond_sub_optab, "cond_sub$a")
+OPTAB_D (cond_fma_rev_optab, "cond_fma_rev$a")
+OPTAB_D (cond_fnma_rev_optab, "cond_fnma_rev$a")
 OPTAB_D (cond_and_optab, "cond_and$a")
 OPTAB_D (cond_ior_optab, "cond_ior$a")
 OPTAB_D (cond_xor_optab, "cond_xor$a")
Index: gcc/internal-fn.def
===================================================================
--- gcc/internal-fn.def	2018-05-16 10:23:03.590853492 +0100
+++ gcc/internal-fn.def	2018-05-16 10:23:03.887838686 +0100
@@ -59,7 +59,8 @@  along with GCC; see the file COPYING3.
    - binary: a normal binary optab, such as vec_interleave_lo_<mode>
    - ternary: a normal ternary optab, such as fma<mode>4
 
-   - cond_binary: a conditional binary optab, such as add<mode>cc
+   - cond_binary: a conditional binary optab, such as cond_add<mode>
+   - cond_ternary: a conditional ternary optab, such as cond_fma_rev<mode>
 
    - fold_left: for scalar = FN (scalar, vector), keyed off the vector mode
 
@@ -143,6 +144,9 @@  DEF_INTERNAL_OPTAB_FN (FMS, ECF_CONST, f
 DEF_INTERNAL_OPTAB_FN (FNMA, ECF_CONST, fnma, ternary)
 DEF_INTERNAL_OPTAB_FN (FNMS, ECF_CONST, fnms, ternary)
 
+DEF_INTERNAL_OPTAB_FN (COND_FMA_REV, ECF_CONST, cond_fma_rev, cond_ternary)
+DEF_INTERNAL_OPTAB_FN (COND_FNMA_REV, ECF_CONST, cond_fnma_rev, cond_ternary)
+
 DEF_INTERNAL_OPTAB_FN (COND_ADD, ECF_CONST, cond_add, cond_binary)
 DEF_INTERNAL_OPTAB_FN (COND_SUB, ECF_CONST, cond_sub, cond_binary)
 DEF_INTERNAL_SIGNED_OPTAB_FN (COND_MIN, ECF_CONST, first,
Index: gcc/internal-fn.h
===================================================================
--- gcc/internal-fn.h	2018-05-16 10:23:03.590853492 +0100
+++ gcc/internal-fn.h	2018-05-16 10:23:03.887838686 +0100
@@ -191,6 +191,8 @@  direct_internal_fn_supported_p (internal
 extern bool set_edom_supported_p (void);
 
 extern internal_fn get_conditional_internal_fn (tree_code);
+extern bool can_interpret_as_conditional_op_p (gimple *, tree_code *,
+					       tree *, tree (&)[3]);
 
 extern bool internal_load_fn_p (internal_fn);
 extern bool internal_store_fn_p (internal_fn);
Index: gcc/internal-fn.c
===================================================================
--- gcc/internal-fn.c	2018-05-16 10:23:03.590853492 +0100
+++ gcc/internal-fn.c	2018-05-16 10:23:03.887838686 +0100
@@ -93,6 +93,7 @@  #define binary_direct { 0, 0, true }
 #define ternary_direct { 0, 0, true }
 #define cond_unary_direct { 1, 1, true }
 #define cond_binary_direct { 1, 1, true }
+#define cond_ternary_direct { 1, 1, true }
 #define while_direct { 0, 2, false }
 #define fold_extract_direct { 2, 2, false }
 #define fold_left_direct { 1, 1, false }
@@ -2972,6 +2973,9 @@  #define expand_cond_unary_optab_fn(FN, S
 #define expand_cond_binary_optab_fn(FN, STMT, OPTAB) \
   expand_direct_optab_fn (FN, STMT, OPTAB, 3)
 
+#define expand_cond_ternary_optab_fn(FN, STMT, OPTAB) \
+  expand_direct_optab_fn (FN, STMT, OPTAB, 4)
+
 #define expand_fold_extract_optab_fn(FN, STMT, OPTAB) \
   expand_direct_optab_fn (FN, STMT, OPTAB, 3)
 
@@ -3054,6 +3058,7 @@  #define direct_binary_optab_supported_p
 #define direct_ternary_optab_supported_p direct_optab_supported_p
 #define direct_cond_unary_optab_supported_p direct_optab_supported_p
 #define direct_cond_binary_optab_supported_p direct_optab_supported_p
+#define direct_cond_ternary_optab_supported_p direct_optab_supported_p
 #define direct_mask_load_optab_supported_p direct_optab_supported_p
 #define direct_load_lanes_optab_supported_p multi_vector_optab_supported_p
 #define direct_mask_load_lanes_optab_supported_p multi_vector_optab_supported_p
@@ -3198,6 +3203,17 @@  #define DEF_INTERNAL_FN(CODE, FLAGS, FNS
   0
 };
 
+/* Invoke T(CODE, IFN) for each conditional function IFN that maps to a
+   tree code CODE.  */
+#define FOR_EACH_CODE_MAPPING(T) \
+  T (PLUS_EXPR, IFN_COND_ADD) \
+  T (MINUS_EXPR, IFN_COND_SUB) \
+  T (MIN_EXPR, IFN_COND_MIN) \
+  T (MAX_EXPR, IFN_COND_MAX) \
+  T (BIT_AND_EXPR, IFN_COND_AND) \
+  T (BIT_IOR_EXPR, IFN_COND_IOR) \
+  T (BIT_XOR_EXPR, IFN_COND_XOR)
+
 /* Return a function that performs the conditional form of CODE, i.e.:
 
      LHS = RHS1 ? RHS2 CODE RHS3 : RHS2
@@ -3210,25 +3226,78 @@  get_conditional_internal_fn (tree_code c
 {
   switch (code)
     {
-    case PLUS_EXPR:
-      return IFN_COND_ADD;
-    case MINUS_EXPR:
-      return IFN_COND_SUB;
-    case MIN_EXPR:
-      return IFN_COND_MIN;
-    case MAX_EXPR:
-      return IFN_COND_MAX;
-    case BIT_AND_EXPR:
-      return IFN_COND_AND;
-    case BIT_IOR_EXPR:
-      return IFN_COND_IOR;
-    case BIT_XOR_EXPR:
-      return IFN_COND_XOR;
+#define CASE(CODE, IFN) case CODE: return IFN;
+      FOR_EACH_CODE_MAPPING(CASE)
+#undef CASE
     default:
       return IFN_LAST;
     }
 }
 
+/* If IFN implements the conditional form of a tree code, return that
+   tree code, otherwise return ERROR_MARK.  */
+
+static tree_code
+conditional_internal_fn_code (internal_fn ifn)
+{
+  switch (ifn)
+    {
+#define CASE(CODE, IFN) case IFN: return CODE;
+      FOR_EACH_CODE_MAPPING(CASE)
+#undef CASE
+    default:
+      return ERROR_MARK;
+    }
+}
+
+/* Return true if STMT can be interpreted as a conditional tree code
+   operation of the form:
+
+     LHS = COND ? OP (RHS1, ...) : RHS1;
+
+   operating elementwise if the operands are vectors.  This includes
+   the case of an all-true COND, so that the operation always happens.
+
+   When returning true, set:
+
+   - *CODE_OUT to the tree code
+   - *COND_OUT to the condition COND, or to NULL_TREE if the condition
+     is known to be all-true
+   - OPS[I] to operand I of *CODE_OUT.  */
+
+bool
+can_interpret_as_conditional_op_p (gimple *stmt, tree_code *code_out,
+				   tree *cond_out, tree (&ops)[3])
+{
+  if (gassign *assign = dyn_cast <gassign *> (stmt))
+    {
+      *code_out = gimple_assign_rhs_code (assign);
+      *cond_out = NULL_TREE;
+      ops[0] = gimple_assign_rhs1 (assign);
+      ops[1] = gimple_assign_rhs2 (assign);
+      ops[2] = gimple_assign_rhs3 (assign);
+      return true;
+    }
+  if (gcall *call = dyn_cast <gcall *> (stmt))
+    if (gimple_call_internal_p (call))
+      {
+	internal_fn ifn = gimple_call_internal_fn (call);
+	tree_code code = conditional_internal_fn_code (ifn);
+	if (code != ERROR_MARK)
+	  {
+	    *code_out = code;
+	    *cond_out = gimple_call_arg (call, 0);
+	    if (integer_truep (*cond_out))
+	      *cond_out = NULL_TREE;
+	    unsigned int nargs = gimple_call_num_args (call) - 1;
+	    for (unsigned int i = 0; i < 3; ++i)
+	      ops[i] = i < nargs ? gimple_call_arg (call, i + 1) : NULL_TREE;
+	    return true;
+	  }
+      }
+  return false;
+}
+
 /* Return true if IFN is some form of load from memory.  */
 
 bool
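FOR_EACH_CODE_MAPPING is an X-macro. A single list of (tree code, internal function) pairs is expanded twice, each time with a different definition of CASE. This keeps get_conditional_internal_fn and its new inverse, conditional_internal_fn_code, from drifting apart. The following self-contained sketch shows the same idiom. The enums and function names are hypothetical stand-ins for tree_code and internal_fn.

```c
#include <assert.h>

/* Hypothetical stand-ins for GCC's tree_code and internal_fn enums.  */
enum op_code { PLUS_OP, MINUS_OP, MIN_OP, MAX_OP, MULT_OP, NO_OP };
enum cond_fn { COND_ADD_FN, COND_SUB_FN, COND_MIN_FN, COND_MAX_FN, NO_FN };

/* The single list of (code, conditional function) pairs.  */
#define FOR_EACH_MAPPING(T) \
  T (PLUS_OP, COND_ADD_FN) \
  T (MINUS_OP, COND_SUB_FN) \
  T (MIN_OP, COND_MIN_FN) \
  T (MAX_OP, COND_MAX_FN)

/* Forward direction: code -> conditional function.  */
static enum cond_fn
conditional_fn (enum op_code code)
{
  switch (code)
    {
#define CASE(CODE, FN) case CODE: return FN;
      FOR_EACH_MAPPING (CASE)
#undef CASE
    default:
      return NO_FN;
    }
}

/* Reverse direction: conditional function -> code.  */
static enum op_code
conditional_fn_code (enum cond_fn fn)
{
  switch (fn)
    {
#define CASE(CODE, FN) case FN: return CODE;
      FOR_EACH_MAPPING (CASE)
#undef CASE
    default:
      return NO_OP;
    }
}

/* Every mapped code survives a round trip through both switches.  */
static void
check_round_trip (void)
{
  for (int c = PLUS_OP; c <= MAX_OP; ++c)
    assert (conditional_fn_code (conditional_fn ((enum op_code) c)) == c);
}
```

Adding a new pair to the list updates both directions at once. Codes without a conditional form fall through to the default case.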
Index: gcc/tree-ssa-math-opts.c
===================================================================
--- gcc/tree-ssa-math-opts.c	2018-05-16 10:23:03.590853492 +0100
+++ gcc/tree-ssa-math-opts.c	2018-05-16 10:23:03.889838586 +0100
@@ -2640,6 +2640,24 @@  convert_plusminus_to_widen (gimple_stmt_
   return true;
 }
 
+/* Return the internal function that implements:
+
+     LHS = COND ? A CODE B * C : A.  */
+
+static internal_fn
+fused_cond_internal_fn (tree_code code)
+{
+  switch (code)
+    {
+    case PLUS_EXPR:
+      return IFN_COND_FMA_REV;
+    case MINUS_EXPR:
+      return IFN_COND_FNMA_REV;
+    default:
+      gcc_unreachable ();
+    }
+}
+
 /* gimple_fold callback that "valueizes" everything.  */
 
 static tree
@@ -2663,7 +2681,6 @@  convert_mult_to_fma_1 (tree mul_result,
   FOR_EACH_IMM_USE_STMT (use_stmt, imm_iter, mul_result)
     {
       gimple_stmt_iterator gsi = gsi_for_stmt (use_stmt);
-      enum tree_code use_code;
       tree addop, mulop1 = op1, result = mul_result;
       bool negate_p = false;
       gimple_seq seq = NULL;
@@ -2671,8 +2688,8 @@  convert_mult_to_fma_1 (tree mul_result,
       if (is_gimple_debug (use_stmt))
 	continue;
 
-      use_code = gimple_assign_rhs_code (use_stmt);
-      if (use_code == NEGATE_EXPR)
+      if (is_gimple_assign (use_stmt)
+	  && gimple_assign_rhs_code (use_stmt) == NEGATE_EXPR)
 	{
 	  result = gimple_assign_lhs (use_stmt);
 	  use_operand_p use_p;
@@ -2683,23 +2700,30 @@  convert_mult_to_fma_1 (tree mul_result,
 
 	  use_stmt = neguse_stmt;
 	  gsi = gsi_for_stmt (use_stmt);
-	  use_code = gimple_assign_rhs_code (use_stmt);
 	  negate_p = true;
 	}
 
-      if (gimple_assign_rhs1 (use_stmt) == result)
-	{
-	  addop = gimple_assign_rhs2 (use_stmt);
-	  /* a * b - c -> a * b + (-c)  */
-	  if (gimple_assign_rhs_code (use_stmt) == MINUS_EXPR)
-	    addop = gimple_build (&seq, NEGATE_EXPR, type, addop);
-	}
+      tree cond, ops[3];
+      tree_code code;
+      if (!can_interpret_as_conditional_op_p (use_stmt, &code, &cond, ops))
+	gcc_unreachable ();
+      addop = ops[0] == result ? ops[1] : ops[0];
+
+      internal_fn ifn;
+      if (cond)
+	ifn = fused_cond_internal_fn (code);
       else
 	{
-	  addop = gimple_assign_rhs1 (use_stmt);
-	  /* a - b * c -> (-b) * c + a */
-	  if (gimple_assign_rhs_code (use_stmt) == MINUS_EXPR)
-	    negate_p = !negate_p;
+	  ifn = IFN_FMA;
+	  if (code == MINUS_EXPR)
+	    {
+	      if (ops[0] == result)
+		/* a * b - c -> a * b + (-c)  */
+		addop = gimple_build (&seq, NEGATE_EXPR, type, addop);
+	      else
+		/* a - b * c -> (-b) * c + a */
+		negate_p = !negate_p;
+	    }
 	}
 
       if (negate_p)
@@ -2707,8 +2731,13 @@  convert_mult_to_fma_1 (tree mul_result,
 
       if (seq)
 	gsi_insert_seq_before (&gsi, seq, GSI_SAME_STMT);
-      fma_stmt = gimple_build_call_internal (IFN_FMA, 3, mulop1, op2, addop);
-      gimple_call_set_lhs (fma_stmt, gimple_assign_lhs (use_stmt));
+
+      if (ifn == IFN_FMA)
+	fma_stmt = gimple_build_call_internal (IFN_FMA, 3, mulop1, op2, addop);
+      else
+	fma_stmt = gimple_build_call_internal (ifn, 4, cond, addop,
+					       mulop1, op2);
+      gimple_set_lhs (fma_stmt, gimple_get_lhs (use_stmt));
       gimple_call_set_nothrow (fma_stmt, !stmt_can_throw_internal (use_stmt));
       gsi_replace (&gsi, fma_stmt, true);
       /* Valueize aggressively so that we generate FMS, FNMA and FNMS
@@ -2891,7 +2920,6 @@  convert_mult_to_fma (gimple *mul_stmt, t
      as an addition.  */
   FOR_EACH_IMM_USE_FAST (use_p, imm_iter, mul_result)
     {
-      enum tree_code use_code;
       tree result = mul_result;
       bool negate_p = false;
 
@@ -2912,13 +2940,9 @@  convert_mult_to_fma (gimple *mul_stmt, t
       if (gimple_bb (use_stmt) != gimple_bb (mul_stmt))
 	return false;
 
-      if (!is_gimple_assign (use_stmt))
-	return false;
-
-      use_code = gimple_assign_rhs_code (use_stmt);
-
       /* A negate on the multiplication leads to FNMA.  */
-      if (use_code == NEGATE_EXPR)
+      if (is_gimple_assign (use_stmt)
+	  && gimple_assign_rhs_code (use_stmt) == NEGATE_EXPR)
 	{
 	  ssa_op_iter iter;
 	  use_operand_p usep;
@@ -2940,17 +2964,19 @@  convert_mult_to_fma (gimple *mul_stmt, t
 	  use_stmt = neguse_stmt;
 	  if (gimple_bb (use_stmt) != gimple_bb (mul_stmt))
 	    return false;
-	  if (!is_gimple_assign (use_stmt))
-	    return false;
 
-	  use_code = gimple_assign_rhs_code (use_stmt);
 	  negate_p = true;
 	}
 
-      switch (use_code)
+      tree cond, ops[3];
+      tree_code code;
+      if (!can_interpret_as_conditional_op_p (use_stmt, &code, &cond, ops))
+	return false;
+
+      switch (code)
 	{
 	case MINUS_EXPR:
-	  if (gimple_assign_rhs2 (use_stmt) == result)
+	  if (ops[1] == result)
 	    negate_p = !negate_p;
 	  break;
 	case PLUS_EXPR:
@@ -2960,47 +2986,52 @@  convert_mult_to_fma (gimple *mul_stmt, t
 	  return false;
 	}
 
-      /* If the subtrahend (gimple_assign_rhs2 (use_stmt)) is computed
-	 by a MULT_EXPR that we'll visit later, we might be able to
-	 get a more profitable match with fnma.
+      if (cond)
+	{
+	  /* The multiplication must be the second operand.  */
+	  if (cond == result || ops[0] == result)
+	    return false;
+	  internal_fn ifn = fused_cond_internal_fn (code);
+	  if (!direct_internal_fn_supported_p (ifn, type, opt_type))
+	    return false;
+	}
+
+      /* If the subtrahend (OPS[1]) is computed by a MULT_EXPR that
+	 we'll visit later, we might be able to get a more profitable
+	 match with fnma.
 	 OTOH, if we don't, a negate / fma pair has likely lower latency
 	 that a mult / subtract pair.  */
-      if (use_code == MINUS_EXPR && !negate_p
-	  && gimple_assign_rhs1 (use_stmt) == result
+      if (code == MINUS_EXPR
+	  && !negate_p
+	  && ops[0] == result
 	  && !direct_internal_fn_supported_p (IFN_FMS, type, opt_type)
-	  && direct_internal_fn_supported_p (IFN_FNMA, type, opt_type))
+	  && direct_internal_fn_supported_p (IFN_FNMA, type, opt_type)
+	  && TREE_CODE (ops[1]) == SSA_NAME
+	  && has_single_use (ops[1]))
 	{
-	  tree rhs2 = gimple_assign_rhs2 (use_stmt);
-
-	  if (TREE_CODE (rhs2) == SSA_NAME)
-	    {
-	      gimple *stmt2 = SSA_NAME_DEF_STMT (rhs2);
-	      if (has_single_use (rhs2)
-		  && is_gimple_assign (stmt2)
-		  && gimple_assign_rhs_code (stmt2) == MULT_EXPR)
-	      return false;
-	    }
+	  gimple *stmt2 = SSA_NAME_DEF_STMT (ops[1]);
+	  if (is_gimple_assign (stmt2)
+	      && gimple_assign_rhs_code (stmt2) == MULT_EXPR)
+	    return false;
 	}
 
-      tree use_rhs1 = gimple_assign_rhs1 (use_stmt);
-      tree use_rhs2 = gimple_assign_rhs2 (use_stmt);
       /* We can't handle a * b + a * b.  */
-      if (use_rhs1 == use_rhs2)
+      if (ops[0] == ops[1])
 	return false;
       /* If deferring, make sure we are not looking at an instruction that
 	 wouldn't have existed if we were not.  */
       if (state->m_deferring_p
-	  && (state->m_mul_result_set.contains (use_rhs1)
-	      || state->m_mul_result_set.contains (use_rhs2)))
+	  && (state->m_mul_result_set.contains (ops[0])
+	      || state->m_mul_result_set.contains (ops[1])))
 	return false;
 
       if (check_defer)
 	{
-	  tree use_lhs = gimple_assign_lhs (use_stmt);
+	  tree use_lhs = gimple_get_lhs (use_stmt);
 	  if (state->m_last_result)
 	    {
-	      if (use_rhs2 == state->m_last_result
-		  || use_rhs1 == state->m_last_result)
+	      if (ops[1] == state->m_last_result
+		  || ops[0] == state->m_last_result)
 		defer = true;
 	      else
 		defer = false;
@@ -3009,12 +3040,12 @@  convert_mult_to_fma (gimple *mul_stmt, t
 	    {
 	      gcc_checking_assert (!state->m_initial_phi);
 	      gphi *phi;
-	      if (use_rhs1 == result)
-		phi = result_of_phi (use_rhs2);
+	      if (ops[0] == result)
+		phi = result_of_phi (ops[1]);
 	      else
 		{
-		  gcc_assert (use_rhs2 == result);
-		  phi = result_of_phi (use_rhs1);
+		  gcc_assert (ops[1] == result);
+		  phi = result_of_phi (ops[0]);
 		}
 
 	      if (phi)
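The net effect of the tree-ssa-math-opts.c changes on a conditional statement can be modelled per lane, as in the hedged sketch below (all names are hypothetical):

- unfused models the original pair prod = b * c; lhs = cond ? a CODE prod : a.
- fused models the single call that convert_mult_to_fma_1 builds, with the addend passed first.

As in convert_mult_to_fma, the product is assumed to be the second operand (OPS[1]). If it were the condition or OPS[0], the condition or the else value would depend on the multiplication, so the statement could not be fused. A plain multiply-add again stands in for a single-rounding fma.

```c
#include <assert.h>
#include <stdbool.h>

/* Hypothetical stand-ins for the tree codes and internal functions.  */
enum op_code { PLUS_OP, MINUS_OP };
enum fused_fn { COND_FMA_REV_FN, COND_FNMA_REV_FN };

/* Mirrors fused_cond_internal_fn: LHS = COND ? A CODE B * C : A.  */
static enum fused_fn
fused_fn_for (enum op_code code)
{
  return code == PLUS_OP ? COND_FMA_REV_FN : COND_FNMA_REV_FN;
}

/* Before: prod = b * c; lhs = cond ? a CODE prod : a.  */
static double
unfused (bool cond, double a, double b, double c, enum op_code code)
{
  double prod = b * c;
  if (!cond)
    return a;
  return code == PLUS_OP ? a + prod : a - prod;
}

/* After: lhs = FN (cond, a, b, c), addend first.  b * c + a stands in
   for a single-rounding fma.  */
static double
fused (bool cond, double a, double b, double c, enum fused_fn fn)
{
  if (!cond)
    return a;
  return fn == COND_FMA_REV_FN ? b * c + a : -b * c + a;
}

/* Check that fusion preserves every lane of an N-element vector.  */
static void
check_lanes (const bool *mask, const double *a, const double *b,
             const double *c, int n, enum op_code code)
{
  for (int i = 0; i < n; ++i)
    assert (unfused (mask[i], a[i], b[i], c[i], code)
            == fused (mask[i], a[i], b[i], c[i], fused_fn_for (code)));
}
```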
Index: gcc/genmatch.c
===================================================================
--- gcc/genmatch.c	2018-05-16 10:23:03.590853492 +0100
+++ gcc/genmatch.c	2018-05-16 10:23:03.887838686 +0100
@@ -485,6 +485,10 @@  commutative_op (id_base *id)
       case CFN_FNMS:
 	return 0;
 
+      case CFN_COND_FMA_REV:
+      case CFN_COND_FNMA_REV:
+	return 2;
+
       default:
 	return -1;
       }
Index: gcc/config/aarch64/iterators.md
===================================================================
--- gcc/config/aarch64/iterators.md	2018-05-16 10:23:03.590853492 +0100
+++ gcc/config/aarch64/iterators.md	2018-05-16 10:23:03.886838736 +0100
@@ -449,6 +449,8 @@  (define_c_enum "unspec"
     UNSPEC_COND_AND	; Used in aarch64-sve.md.
     UNSPEC_COND_ORR	; Used in aarch64-sve.md.
     UNSPEC_COND_EOR	; Used in aarch64-sve.md.
+    UNSPEC_COND_FMLA	; Used in aarch64-sve.md.
+    UNSPEC_COND_FMLS	; Used in aarch64-sve.md.
     UNSPEC_COND_LT	; Used in aarch64-sve.md.
     UNSPEC_COND_LE	; Used in aarch64-sve.md.
     UNSPEC_COND_EQ	; Used in aarch64-sve.md.
@@ -1499,14 +1501,16 @@  (define_int_iterator UNPACK_UNSIGNED [UN
 
 (define_int_iterator MUL_HIGHPART [UNSPEC_SMUL_HIGHPART UNSPEC_UMUL_HIGHPART])
 
-(define_int_iterator SVE_COND_INT_OP [UNSPEC_COND_ADD UNSPEC_COND_SUB
-				      UNSPEC_COND_SMAX UNSPEC_COND_UMAX
-				      UNSPEC_COND_SMIN UNSPEC_COND_UMIN
-				      UNSPEC_COND_AND
-				      UNSPEC_COND_ORR
-				      UNSPEC_COND_EOR])
+(define_int_iterator SVE_COND_INT2_OP [UNSPEC_COND_ADD UNSPEC_COND_SUB
+				       UNSPEC_COND_SMAX UNSPEC_COND_UMAX
+				       UNSPEC_COND_SMIN UNSPEC_COND_UMIN
+				       UNSPEC_COND_AND
+				       UNSPEC_COND_ORR
+				       UNSPEC_COND_EOR])
 
-(define_int_iterator SVE_COND_FP_OP [UNSPEC_COND_ADD UNSPEC_COND_SUB])
+(define_int_iterator SVE_COND_FP2_OP [UNSPEC_COND_ADD UNSPEC_COND_SUB])
+
+(define_int_iterator SVE_COND_FP3_OP [UNSPEC_COND_FMLA UNSPEC_COND_FMLS])
 
 (define_int_iterator SVE_COND_FP_CMP [UNSPEC_COND_LT UNSPEC_COND_LE
 				      UNSPEC_COND_EQ UNSPEC_COND_NE
@@ -1543,7 +1547,9 @@  (define_int_attr optab [(UNSPEC_ANDF "an
 			(UNSPEC_COND_UMIN "umin")
 			(UNSPEC_COND_AND "and")
 			(UNSPEC_COND_ORR "ior")
-			(UNSPEC_COND_EOR "xor")])
+			(UNSPEC_COND_EOR "xor")
+			(UNSPEC_COND_FMLA "fma_rev")
+			(UNSPEC_COND_FMLS "fnma_rev")])
 
 (define_int_attr  maxmin_uns [(UNSPEC_UMAXV "umax")
 			      (UNSPEC_UMINV "umin")
@@ -1762,4 +1768,6 @@  (define_int_attr sve_int_op [(UNSPEC_CON
 			     (UNSPEC_COND_EOR "eor")])
 
 (define_int_attr sve_fp_op [(UNSPEC_COND_ADD "fadd")
-			    (UNSPEC_COND_SUB "fsub")])
+			    (UNSPEC_COND_SUB "fsub")
+			    (UNSPEC_COND_FMLA "fmla")
+			    (UNSPEC_COND_FMLS "fmls")])
Index: gcc/config/aarch64/aarch64-sve.md
===================================================================
--- gcc/config/aarch64/aarch64-sve.md	2018-05-16 10:23:03.590853492 +0100
+++ gcc/config/aarch64/aarch64-sve.md	2018-05-16 10:23:03.883838885 +0100
@@ -1764,7 +1764,7 @@  (define_insn "cond_<optab><mode>"
 	  [(match_operand:<VPRED> 1 "register_operand" "Upl")
 	   (match_operand:SVE_I 2 "register_operand" "0")
 	   (match_operand:SVE_I 3 "register_operand" "w")]
-	  SVE_COND_INT_OP))]
+	  SVE_COND_INT2_OP))]
   "TARGET_SVE"
   "<sve_int_op>\t%0.<Vetype>, %1/m, %0.<Vetype>, %3.<Vetype>"
 )
@@ -2543,11 +2543,23 @@  (define_insn "cond_<optab><mode>"
 	  [(match_operand:<VPRED> 1 "register_operand" "Upl")
 	   (match_operand:SVE_F 2 "register_operand" "0")
 	   (match_operand:SVE_F 3 "register_operand" "w")]
-	  SVE_COND_FP_OP))]
+	  SVE_COND_FP2_OP))]
   "TARGET_SVE"
   "<sve_fp_op>\t%0.<Vetype>, %1/m, %0.<Vetype>, %3.<Vetype>"
 )
 
+(define_insn "cond_<optab><mode>"
+  [(set (match_operand:SVE_F 0 "register_operand" "=w")
+	(unspec:SVE_F
+	  [(match_operand:<VPRED> 1 "register_operand" "Upl")
+	   (match_operand:SVE_F 2 "register_operand" "0")
+	   (match_operand:SVE_F 3 "register_operand" "w")
+	   (match_operand:SVE_F 4 "register_operand" "w")]
+	  SVE_COND_FP3_OP))]
+  "TARGET_SVE"
+  "<sve_fp_op>\t%0.<Vetype>, %1/m, %3.<Vetype>, %4.<Vetype>"
+)
+
 ;; Shift an SVE vector left and insert a scalar into element 0.
 (define_insn "vec_shl_insert_<mode>"
   [(set (match_operand:SVE_ALL 0 "register_operand" "=w, w")
Index: gcc/testsuite/gcc.target/aarch64/sve/reduc_4.c
===================================================================
--- /dev/null	2018-04-20 16:19:46.369131350 +0100
+++ gcc/testsuite/gcc.target/aarch64/sve/reduc_4.c	2018-05-16 10:23:03.888838636 +0100
@@ -0,0 +1,18 @@ 
+/* { dg-do compile } */
+/* { dg-options "-O2 -ftree-vectorize -ffast-math" } */
+
+double
+f (double *restrict a, double *restrict b, int *lookup)
+{
+  double res = 0.0;
+  for (int i = 0; i < 512; ++i)
+    res += a[lookup[i]] * b[i];
+  return res;
+}
+
+/* { dg-final { scan-assembler-times {\tfmla\tz[0-9]+.d, p[0-7]/m, } 2 } } */
+/* Check that the vector instructions are the only instructions.  */
+/* { dg-final { scan-assembler-times {\tfmla\t} 2 } } */
+/* { dg-final { scan-assembler-not {\tfadd\t} } } */
+/* { dg-final { scan-assembler-times {\tfaddv\td0,} 1 } } */
+/* { dg-final { scan-assembler-not {\tsel\t} } } */
Index: gcc/testsuite/gcc.target/aarch64/sve/reduc_6.c
===================================================================
--- /dev/null	2018-04-20 16:19:46.369131350 +0100
+++ gcc/testsuite/gcc.target/aarch64/sve/reduc_6.c	2018-05-16 10:23:03.888838636 +0100
@@ -0,0 +1,17 @@ 
+/* { dg-do compile } */
+/* { dg-options "-O2 -ftree-vectorize -ffast-math" } */
+
+#define REDUC(TYPE)						\
+  TYPE reduc_##TYPE (TYPE *x, TYPE *y, int count)		\
+  {								\
+    TYPE sum = 0;						\
+    for (int i = 0; i < count; ++i)				\
+      sum += x[i] * y[i];					\
+    return sum;							\
+  }
+
+REDUC (float)
+REDUC (double)
+
+/* { dg-final { scan-assembler-times {\tfmla\tz[0-9]+\.s, p[0-7]/m} 1 } } */
+/* { dg-final { scan-assembler-times {\tfmla\tz[0-9]+\.d, p[0-7]/m} 1 } } */
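The loops in reduc_6.c have the same dot-product shape as the example in the commit message. The hedged sketch below emulates in scalar C roughly what the fully-masked vector loop computes. VL = 4 and masked_dot are assumptions for illustration; the SVE vector length is not fixed at compile time. Each iteration handles VL lanes under a loop mask. Inactive lanes keep their partial sum, as with IFN_COND_FMA_REV, and a final horizontal add plays the role of FADDV. Summing per-lane partial results reorders the additions, which is normally allowed only under options such as -ffast-math; these tests enable it.

```c
#include <assert.h>
#include <stdbool.h>

/* Hypothetical vector length in lanes; on SVE it is a runtime property.  */
#define VL 4

/* Emulates a fully-masked loop for
     for (int i = 0; i < count; ++i) sum += x[i] * y[i];  */
static double
masked_dot (const double *x, const double *y, int count)
{
  assert (count >= 0);
  double acc[VL] = { 0.0 };  /* Per-lane partial sums, initially zero.  */
  for (int base = 0; base < count; base += VL)
    for (int lane = 0; lane < VL; ++lane)
      {
        bool active = base + lane < count;  /* The loop mask.  */
        /* Masked loads: inactive lanes do not read memory.  */
        double xv = active ? x[base + lane] : 0.0;
        double yv = active ? y[base + lane] : 0.0;
        /* Conditional FMA with the accumulator first; the plain
           multiply-add stands in for a fused operation.  */
        acc[lane] = active ? xv * yv + acc[lane] : acc[lane];
      }
  /* Horizontal reduction of the partial sums (the FADDV step).  */
  double sum = 0.0;
  for (int lane = 0; lane < VL; ++lane)
    sum += acc[lane];
  return sum;
}
```

With count = 7, the second iteration has only three active lanes. The fourth lane's partial sum passes through unchanged, so no separate SEL or tail loop is needed.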
Index: gcc/testsuite/gcc.target/aarch64/sve/reduc_7.c
===================================================================
--- /dev/null	2018-04-20 16:19:46.369131350 +0100
+++ gcc/testsuite/gcc.target/aarch64/sve/reduc_7.c	2018-05-16 10:23:03.889838586 +0100
@@ -0,0 +1,17 @@ 
+/* { dg-do compile } */
+/* { dg-options "-O2 -ftree-vectorize -ffast-math" } */
+
+#define REDUC(TYPE)						\
+  TYPE reduc_##TYPE (TYPE *x, TYPE *y, int count)		\
+  {								\
+    TYPE sum = 0;						\
+    for (int i = 0; i < count; ++i)				\
+      sum -= x[i] * y[i];					\
+    return sum;							\
+  }
+
+REDUC (float)
+REDUC (double)
+
+/* { dg-final { scan-assembler-times {\tfmls\tz[0-9]+\.s, p[0-7]/m} 1 } } */
+/* { dg-final { scan-assembler-times {\tfmls\tz[0-9]+\.d, p[0-7]/m} 1 } } */