Support fused multiply-adds in fully-masked reductions
author    Richard Sandiford <richard.sandiford@linaro.org>
          Thu, 12 Jul 2018 13:01:48 +0000 (13:01 +0000)
committer Richard Sandiford <rsandifo@gcc.gnu.org>
          Thu, 12 Jul 2018 13:01:48 +0000 (13:01 +0000)
This patch adds support for fusing a conditional add or subtract
with a multiplication, so that we can use fused multiply-add and
multiply-subtract operations for fully-masked reductions.  E.g.
for SVE we vectorise:

  double res = 0.0;
  for (int i = 0; i < n; ++i)
    res += x[i] * y[i];

using a fully-masked loop in which the loop body has the form:

  res_1 = PHI <0(preheader), res_2(latch)>;
  avec = .MASK_LOAD (loop_mask, x);
  bvec = .MASK_LOAD (loop_mask, y);
  prod = avec * bvec;
  res_2 = .COND_ADD (loop_mask, res_1, prod, res_1);

where the last statement does the equivalent of:

  res_2 = loop_mask ? res_1 + prod : res_1;

(operating elementwise).  The point of the patch is to convert the last
two statements into:

  res_2 = .COND_FMA (loop_mask, avec, bvec, res_1, res_1);

which is equivalent to:

  res_2 = loop_mask ? fma (avec, bvec, res_1) : res_1;

(again operating elementwise).
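
For intuition, here is a scalar model of one lane of .COND_FMA (a
sketch only; the helper name is made up for illustration, and
__builtin_fma is the only real interface it uses):

  /* One lane of .COND_FMA: active lanes compute a fused
     multiply-add into the accumulator, inactive lanes keep
     the fallback value.  */
  static inline double
  cond_fma_lane (_Bool mask, double a, double b,
                 double accum, double fallback)
  {
    return mask ? __builtin_fma (a, b, accum) : fallback;
  }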

2018-07-12  Richard Sandiford  <richard.sandiford@linaro.org>
    Alan Hayward  <alan.hayward@arm.com>
    David Sherwood  <david.sherwood@arm.com>

gcc/
* internal-fn.h (can_interpret_as_conditional_op_p): Declare.
* internal-fn.c (can_interpret_as_conditional_op_p): New function.
* tree-ssa-math-opts.c (convert_mult_to_fma_1): Handle conditional
plus and minus and convert them into IFN_COND_FMA-based sequences.
(convert_mult_to_fma): Handle conditional plus and minus.

gcc/testsuite/
* gcc.dg/vect/vect-fma-2.c: New test.
* gcc.target/aarch64/sve/reduc_4.c: Likewise.
* gcc.target/aarch64/sve/reduc_6.c: Likewise.
* gcc.target/aarch64/sve/reduc_7.c: Likewise.

Co-Authored-By: Alan Hayward <alan.hayward@arm.com>
Co-Authored-By: David Sherwood <david.sherwood@arm.com>
From-SVN: r262588

gcc/ChangeLog
gcc/internal-fn.c
gcc/internal-fn.h
gcc/testsuite/ChangeLog
gcc/testsuite/gcc.dg/vect/vect-fma-2.c [new file with mode: 0644]
gcc/testsuite/gcc.target/aarch64/sve/reduc_4.c [new file with mode: 0644]
gcc/testsuite/gcc.target/aarch64/sve/reduc_6.c [new file with mode: 0644]
gcc/testsuite/gcc.target/aarch64/sve/reduc_7.c [new file with mode: 0644]
gcc/tree-ssa-math-opts.c

diff --git a/gcc/ChangeLog b/gcc/ChangeLog
index a12938054c7f5d1d494a4c4645238610ea06e6cc..20ed355ca124bdc675332ceace84a8e312cbb62f 100644
--- a/gcc/ChangeLog
+++ b/gcc/ChangeLog
@@ -1,3 +1,13 @@
+2018-07-12  Richard Sandiford  <richard.sandiford@linaro.org>
+           Alan Hayward  <alan.hayward@arm.com>
+           David Sherwood  <david.sherwood@arm.com>
+
+       * internal-fn.h (can_interpret_as_conditional_op_p): Declare.
+       * internal-fn.c (can_interpret_as_conditional_op_p): New function.
+       * tree-ssa-math-opts.c (convert_mult_to_fma_1): Handle conditional
+       plus and minus and convert them into IFN_COND_FMA-based sequences.
+       (convert_mult_to_fma): Handle conditional plus and minus.
+
 2018-07-12  Richard Sandiford  <richard.sandiford@linaro.org>
 
        * doc/md.texi (cond_fma, cond_fms, cond_fnma, cond_fnms): Document.
diff --git a/gcc/internal-fn.c b/gcc/internal-fn.c
index 474a16bc6fb12df44fced67f17e74cc28faef83f..15755ea06fdaca14ebbfcbfb80e7ab18418f3a0d 100644
--- a/gcc/internal-fn.c
+++ b/gcc/internal-fn.c
@@ -3333,6 +3333,62 @@ get_unconditional_internal_fn (internal_fn ifn)
     }
 }
 
+/* Return true if STMT can be interpreted as a conditional tree code
+   operation of the form:
+
+     LHS = COND ? OP (RHS1, ...) : ELSE;
+
+   operating elementwise if the operands are vectors.  This includes
+   the case of an all-true COND, so that the operation always happens.
+
+   When returning true, set:
+
+   - *COND_OUT to the condition COND, or to NULL_TREE if the condition
+     is known to be all-true
+   - *CODE_OUT to the tree code
+   - OPS[I] to operand I of *CODE_OUT
+   - *ELSE_OUT to the fallback value ELSE, or to NULL_TREE if the
+     condition is known to be all-true.  */
+
+bool
+can_interpret_as_conditional_op_p (gimple *stmt, tree *cond_out,
+                                  tree_code *code_out,
+                                  tree (&ops)[3], tree *else_out)
+{
+  if (gassign *assign = dyn_cast <gassign *> (stmt))
+    {
+      *cond_out = NULL_TREE;
+      *code_out = gimple_assign_rhs_code (assign);
+      ops[0] = gimple_assign_rhs1 (assign);
+      ops[1] = gimple_assign_rhs2 (assign);
+      ops[2] = gimple_assign_rhs3 (assign);
+      *else_out = NULL_TREE;
+      return true;
+    }
+  if (gcall *call = dyn_cast <gcall *> (stmt))
+    if (gimple_call_internal_p (call))
+      {
+       internal_fn ifn = gimple_call_internal_fn (call);
+       tree_code code = conditional_internal_fn_code (ifn);
+       if (code != ERROR_MARK)
+         {
+           *cond_out = gimple_call_arg (call, 0);
+           *code_out = code;
+           unsigned int nops = gimple_call_num_args (call) - 2;
+           for (unsigned int i = 0; i < 3; ++i)
+             ops[i] = i < nops ? gimple_call_arg (call, i + 1) : NULL_TREE;
+           *else_out = gimple_call_arg (call, nops + 1);
+           if (integer_truep (*cond_out))
+             {
+               *cond_out = NULL_TREE;
+               *else_out = NULL_TREE;
+             }
+           return true;
+         }
+      }
+  return false;
+}
+
 /* Return true if IFN is some form of load from memory.  */
 
 bool
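
A rough sketch of the calling convention for the new function,
assuming GCC's usual middle-end environment (gimple.h, tree.h);
dump_conditional_view is a hypothetical helper, not part of the
patch:

  /* Inspect STMT through the new interface.  COND and ELSE_VALUE
     come back as NULL_TREE when the operation is unconditional,
     so plain assignments and IFN_COND_* calls share one path.  */
  static void
  dump_conditional_view (gimple *stmt)
  {
    tree cond, else_value, ops[3];
    tree_code code;
    if (can_interpret_as_conditional_op_p (stmt, &cond, &code,
                                           ops, &else_value))
      fprintf (stderr, "%s %s\n",
               cond ? "conditional" : "unconditional",
               get_tree_code_name (code));
  }
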
diff --git a/gcc/internal-fn.h b/gcc/internal-fn.h
index 7105c3bbff833ba6956e649ab7ba69209022911e..2296ca0c53903fed2d177d3d03a2046e4c2c6253 100644
--- a/gcc/internal-fn.h
+++ b/gcc/internal-fn.h
@@ -196,6 +196,9 @@ extern internal_fn get_conditional_internal_fn (tree_code);
 extern internal_fn get_conditional_internal_fn (internal_fn);
 extern tree_code conditional_internal_fn_code (internal_fn);
 extern internal_fn get_unconditional_internal_fn (internal_fn);
+extern bool can_interpret_as_conditional_op_p (gimple *, tree *,
+                                              tree_code *, tree (&)[3],
+                                              tree *);
 
 extern bool internal_load_fn_p (internal_fn);
 extern bool internal_store_fn_p (internal_fn);
diff --git a/gcc/testsuite/ChangeLog b/gcc/testsuite/ChangeLog
index 9e5287d4742360e897a387c9403163cd19140246..8291e3d72f0e060a420e8b4d036e581c2cf209b3 100644
--- a/gcc/testsuite/ChangeLog
+++ b/gcc/testsuite/ChangeLog
@@ -1,3 +1,12 @@
+2018-07-12  Richard Sandiford  <richard.sandiford@linaro.org>
+           Alan Hayward  <alan.hayward@arm.com>
+           David Sherwood  <david.sherwood@arm.com>
+
+       * gcc.dg/vect/vect-fma-2.c: New test.
+       * gcc.target/aarch64/sve/reduc_4.c: Likewise.
+       * gcc.target/aarch64/sve/reduc_6.c: Likewise.
+       * gcc.target/aarch64/sve/reduc_7.c: Likewise.
+
 2018-07-12  Richard Sandiford  <richard.sandiford@linaro.org>
 
        * gcc.dg/vect/vect-cond-arith-3.c: New test.
diff --git a/gcc/testsuite/gcc.dg/vect/vect-fma-2.c b/gcc/testsuite/gcc.dg/vect/vect-fma-2.c
new file mode 100644
index 0000000..20d1baf
--- /dev/null
+++ b/gcc/testsuite/gcc.dg/vect/vect-fma-2.c
@@ -0,0 +1,17 @@
+/* { dg-do compile } */
+/* { dg-additional-options "-fdump-tree-optimized -fassociative-math -fno-trapping-math -fno-signed-zeros" } */
+
+#include "tree-vect.h"
+
+#define N (VECTOR_BITS * 11 / 64 + 3)
+
+double
+dot_prod (double *x, double *y)
+{
+  double sum = 0;
+  for (int i = 0; i < N; ++i)
+    sum += x[i] * y[i];
+  return sum;
+}
+
+/* { dg-final { scan-tree-dump { = \.COND_FMA } "optimized" { target { vect_double && { vect_fully_masked && scalar_all_fma } } } } } */
diff --git a/gcc/testsuite/gcc.target/aarch64/sve/reduc_4.c b/gcc/testsuite/gcc.target/aarch64/sve/reduc_4.c
new file mode 100644
index 0000000..eb4b231
--- /dev/null
+++ b/gcc/testsuite/gcc.target/aarch64/sve/reduc_4.c
@@ -0,0 +1,18 @@
+/* { dg-do compile } */
+/* { dg-options "-O2 -ftree-vectorize -ffast-math" } */
+
+double
+f (double *restrict a, double *restrict b, int *lookup)
+{
+  double res = 0.0;
+  for (int i = 0; i < 512; ++i)
+    res += a[lookup[i]] * b[i];
+  return res;
+}
+
+/* { dg-final { scan-assembler-times {\tfmla\tz[0-9]+\.d, p[0-7]/m, } 2 } } */
+/* Check that the vector instructions are the only instructions.  */
+/* { dg-final { scan-assembler-times {\tfmla\t} 2 } } */
+/* { dg-final { scan-assembler-not {\tfadd\t} } } */
+/* { dg-final { scan-assembler-times {\tfaddv\td0,} 1 } } */
+/* { dg-final { scan-assembler-not {\tsel\t} } } */
diff --git a/gcc/testsuite/gcc.target/aarch64/sve/reduc_6.c b/gcc/testsuite/gcc.target/aarch64/sve/reduc_6.c
new file mode 100644
index 0000000..65647c4
--- /dev/null
+++ b/gcc/testsuite/gcc.target/aarch64/sve/reduc_6.c
@@ -0,0 +1,17 @@
+/* { dg-do compile } */
+/* { dg-options "-O2 -ftree-vectorize -ffast-math" } */
+
+#define REDUC(TYPE)                                            \
+  TYPE reduc_##TYPE (TYPE *x, TYPE *y, int count)              \
+  {                                                            \
+    TYPE sum = 0;                                              \
+    for (int i = 0; i < count; ++i)                            \
+      sum += x[i] * y[i];                                      \
+    return sum;                                                        \
+  }
+
+REDUC (float)
+REDUC (double)
+
+/* { dg-final { scan-assembler-times {\tfmla\tz[0-9]+\.s, p[0-7]/m} 1 } } */
+/* { dg-final { scan-assembler-times {\tfmla\tz[0-9]+\.d, p[0-7]/m} 1 } } */
diff --git a/gcc/testsuite/gcc.target/aarch64/sve/reduc_7.c b/gcc/testsuite/gcc.target/aarch64/sve/reduc_7.c
new file mode 100644
index 0000000..b4b408a
--- /dev/null
+++ b/gcc/testsuite/gcc.target/aarch64/sve/reduc_7.c
@@ -0,0 +1,17 @@
+/* { dg-do compile } */
+/* { dg-options "-O2 -ftree-vectorize -ffast-math" } */
+
+#define REDUC(TYPE)                                            \
+  TYPE reduc_##TYPE (TYPE *x, TYPE *y, int count)              \
+  {                                                            \
+    TYPE sum = 0;                                              \
+    for (int i = 0; i < count; ++i)                            \
+      sum -= x[i] * y[i];                                      \
+    return sum;                                                        \
+  }
+
+REDUC (float)
+REDUC (double)
+
+/* { dg-final { scan-assembler-times {\tfmls\tz[0-9]+\.s, p[0-7]/m} 1 } } */
+/* { dg-final { scan-assembler-times {\tfmls\tz[0-9]+\.d, p[0-7]/m} 1 } } */
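
The fmla/fmls patterns above match the merging-predicated SVE forms
(the /m qualifier).  Per lane they behave roughly as sketched below
(an illustration of the architectural semantics, not code from the
patch):

  /* Approximate per-lane behaviour of FMLA/FMLS Zda, Pg/m, Zn, Zm.
     Inactive lanes leave the accumulator unchanged, which is what
     makes these instructions usable for fully-masked reductions.  */
  double fmla_lane (_Bool pg, double zda, double zn, double zm)
  { return pg ? __builtin_fma (zn, zm, zda) : zda; }
  double fmls_lane (_Bool pg, double zda, double zn, double zm)
  { return pg ? __builtin_fma (-zn, zm, zda) : zda; }
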
diff --git a/gcc/tree-ssa-math-opts.c b/gcc/tree-ssa-math-opts.c
index 187ca5a3b019b489138bc9cd75b77be05fa76130..e32669dc944b66d749fc5bc2b0952ca630f3a878 100644
--- a/gcc/tree-ssa-math-opts.c
+++ b/gcc/tree-ssa-math-opts.c
@@ -2655,7 +2655,6 @@ convert_mult_to_fma_1 (tree mul_result, tree op1, tree op2)
   FOR_EACH_IMM_USE_STMT (use_stmt, imm_iter, mul_result)
     {
       gimple_stmt_iterator gsi = gsi_for_stmt (use_stmt);
-      enum tree_code use_code;
       tree addop, mulop1 = op1, result = mul_result;
       bool negate_p = false;
       gimple_seq seq = NULL;
@@ -2663,8 +2662,8 @@ convert_mult_to_fma_1 (tree mul_result, tree op1, tree op2)
       if (is_gimple_debug (use_stmt))
        continue;
 
-      use_code = gimple_assign_rhs_code (use_stmt);
-      if (use_code == NEGATE_EXPR)
+      if (is_gimple_assign (use_stmt)
+         && gimple_assign_rhs_code (use_stmt) == NEGATE_EXPR)
        {
          result = gimple_assign_lhs (use_stmt);
          use_operand_p use_p;
@@ -2675,22 +2674,23 @@ convert_mult_to_fma_1 (tree mul_result, tree op1, tree op2)
 
          use_stmt = neguse_stmt;
          gsi = gsi_for_stmt (use_stmt);
-         use_code = gimple_assign_rhs_code (use_stmt);
          negate_p = true;
        }
 
-      if (gimple_assign_rhs1 (use_stmt) == result)
+      tree cond, else_value, ops[3];
+      tree_code code;
+      if (!can_interpret_as_conditional_op_p (use_stmt, &cond, &code,
+                                             ops, &else_value))
+       gcc_unreachable ();
+      addop = ops[0] == result ? ops[1] : ops[0];
+
+      if (code == MINUS_EXPR)
        {
-         addop = gimple_assign_rhs2 (use_stmt);
-         /* a * b - c -> a * b + (-c)  */
-         if (gimple_assign_rhs_code (use_stmt) == MINUS_EXPR)
+         if (ops[0] == result)
+           /* a * b - c -> a * b + (-c)  */
            addop = gimple_build (&seq, NEGATE_EXPR, type, addop);
-       }
-      else
-       {
-         addop = gimple_assign_rhs1 (use_stmt);
-         /* a - b * c -> (-b) * c + a */
-         if (gimple_assign_rhs_code (use_stmt) == MINUS_EXPR)
+         else
+           /* a - b * c -> (-b) * c + a */
            negate_p = !negate_p;
        }
 
@@ -2699,8 +2699,13 @@ convert_mult_to_fma_1 (tree mul_result, tree op1, tree op2)
 
       if (seq)
        gsi_insert_seq_before (&gsi, seq, GSI_SAME_STMT);
-      fma_stmt = gimple_build_call_internal (IFN_FMA, 3, mulop1, op2, addop);
-      gimple_call_set_lhs (fma_stmt, gimple_assign_lhs (use_stmt));
+
+      if (cond)
+       fma_stmt = gimple_build_call_internal (IFN_COND_FMA, 5, cond, mulop1,
+                                              op2, addop, else_value);
+      else
+       fma_stmt = gimple_build_call_internal (IFN_FMA, 3, mulop1, op2, addop);
+      gimple_set_lhs (fma_stmt, gimple_get_lhs (use_stmt));
       gimple_call_set_nothrow (fma_stmt, !stmt_can_throw_internal (use_stmt));
       gsi_replace (&gsi, fma_stmt, true);
       /* Follow all SSA edges so that we generate FMS, FNMA and FNMS
@@ -2883,7 +2888,6 @@ convert_mult_to_fma (gimple *mul_stmt, tree op1, tree op2,
      as an addition.  */
   FOR_EACH_IMM_USE_FAST (use_p, imm_iter, mul_result)
     {
-      enum tree_code use_code;
       tree result = mul_result;
       bool negate_p = false;
 
@@ -2904,13 +2908,9 @@ convert_mult_to_fma (gimple *mul_stmt, tree op1, tree op2,
       if (gimple_bb (use_stmt) != gimple_bb (mul_stmt))
        return false;
 
-      if (!is_gimple_assign (use_stmt))
-       return false;
-
-      use_code = gimple_assign_rhs_code (use_stmt);
-
       /* A negate on the multiplication leads to FNMA.  */
-      if (use_code == NEGATE_EXPR)
+      if (is_gimple_assign (use_stmt)
+         && gimple_assign_rhs_code (use_stmt) == NEGATE_EXPR)
        {
          ssa_op_iter iter;
          use_operand_p usep;
@@ -2932,17 +2932,20 @@ convert_mult_to_fma (gimple *mul_stmt, tree op1, tree op2,
          use_stmt = neguse_stmt;
          if (gimple_bb (use_stmt) != gimple_bb (mul_stmt))
            return false;
-         if (!is_gimple_assign (use_stmt))
-           return false;
 
-         use_code = gimple_assign_rhs_code (use_stmt);
          negate_p = true;
        }
 
-      switch (use_code)
+      tree cond, else_value, ops[3];
+      tree_code code;
+      if (!can_interpret_as_conditional_op_p (use_stmt, &cond, &code, ops,
+                                             &else_value))
+       return false;
+
+      switch (code)
        {
        case MINUS_EXPR:
-         if (gimple_assign_rhs2 (use_stmt) == result)
+         if (ops[1] == result)
            negate_p = !negate_p;
          break;
        case PLUS_EXPR:
@@ -2952,47 +2955,50 @@ convert_mult_to_fma (gimple *mul_stmt, tree op1, tree op2,
          return false;
        }
 
-      /* If the subtrahend (gimple_assign_rhs2 (use_stmt)) is computed
-        by a MULT_EXPR that we'll visit later, we might be able to
-        get a more profitable match with fnma.
+      if (cond)
+       {
+         if (cond == result || else_value == result)
+           return false;
+         if (!direct_internal_fn_supported_p (IFN_COND_FMA, type, opt_type))
+           return false;
+       }
+
+      /* If the subtrahend (OPS[1]) is computed by a MULT_EXPR that
+        we'll visit later, we might be able to get a more profitable
+        match with fnma.
         OTOH, if we don't, a negate / fma pair has likely lower latency
         than a mult / subtract pair.  */
-      if (use_code == MINUS_EXPR && !negate_p
-         && gimple_assign_rhs1 (use_stmt) == result
+      if (code == MINUS_EXPR
+         && !negate_p
+         && ops[0] == result
          && !direct_internal_fn_supported_p (IFN_FMS, type, opt_type)
-         && direct_internal_fn_supported_p (IFN_FNMA, type, opt_type))
+         && direct_internal_fn_supported_p (IFN_FNMA, type, opt_type)
+         && TREE_CODE (ops[1]) == SSA_NAME
+         && has_single_use (ops[1]))
        {
-         tree rhs2 = gimple_assign_rhs2 (use_stmt);
-
-         if (TREE_CODE (rhs2) == SSA_NAME)
-           {
-             gimple *stmt2 = SSA_NAME_DEF_STMT (rhs2);
-             if (has_single_use (rhs2)
-                 && is_gimple_assign (stmt2)
-                 && gimple_assign_rhs_code (stmt2) == MULT_EXPR)
-             return false;
-           }
+         gimple *stmt2 = SSA_NAME_DEF_STMT (ops[1]);
+         if (is_gimple_assign (stmt2)
+             && gimple_assign_rhs_code (stmt2) == MULT_EXPR)
+           return false;
        }
 
-      tree use_rhs1 = gimple_assign_rhs1 (use_stmt);
-      tree use_rhs2 = gimple_assign_rhs2 (use_stmt);
       /* We can't handle a * b + a * b.  */
-      if (use_rhs1 == use_rhs2)
+      if (ops[0] == ops[1])
        return false;
       /* If deferring, make sure we are not looking at an instruction that
         wouldn't have existed if we were not.  */
       if (state->m_deferring_p
-         && (state->m_mul_result_set.contains (use_rhs1)
-             || state->m_mul_result_set.contains (use_rhs2)))
+         && (state->m_mul_result_set.contains (ops[0])
+             || state->m_mul_result_set.contains (ops[1])))
        return false;
 
       if (check_defer)
        {
-         tree use_lhs = gimple_assign_lhs (use_stmt);
+         tree use_lhs = gimple_get_lhs (use_stmt);
          if (state->m_last_result)
            {
-             if (use_rhs2 == state->m_last_result
-                 || use_rhs1 == state->m_last_result)
+             if (ops[1] == state->m_last_result
+                 || ops[0] == state->m_last_result)
                defer = true;
              else
                defer = false;
@@ -3001,12 +3007,12 @@ convert_mult_to_fma (gimple *mul_stmt, tree op1, tree op2,
            {
              gcc_checking_assert (!state->m_initial_phi);
              gphi *phi;
-             if (use_rhs1 == result)
-               phi = result_of_phi (use_rhs2);
+             if (ops[0] == result)
+               phi = result_of_phi (ops[1]);
              else
                {
-                 gcc_assert (use_rhs2 == result);
-                 phi = result_of_phi (use_rhs1);
+                 gcc_assert (ops[1] == result);
+                 phi = result_of_phi (ops[0]);
                }
 
              if (phi)
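
To illustrate the deferral heuristic around IFN_FNMA above: the
interesting shape is a subtraction whose subtrahend is itself a
multiplication that the pass will visit later, e.g. (sketch):

  double
  f (double a, double b, double c, double d)
  {
    /* c * d may later fuse into an FNMA, -(c * d) + (a * b),
       so fusing a * b into an FMS here could be a loss.  */
    return a * b - c * d;
  }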