slp: support complex FMS and complex FMS conjugate
authorTamar Christina <tamar.christina@arm.com>
Thu, 14 Jan 2021 20:59:12 +0000 (20:59 +0000)
committerTamar Christina <tamar.christina@arm.com>
Thu, 14 Jan 2021 20:59:12 +0000 (20:59 +0000)
This adds support for FMS and FMS conjugated to the slp pattern matcher.

Example of matches:

#include <stdio.h>
#include <complex.h>

#define N 200
#define ROT
#define TYPE float
#define TYPE2 float

void g (TYPE2 complex a[restrict N], TYPE complex b[restrict N], TYPE complex c[restrict N])
{
  for (int i=0; i < N; i++)
    {
      c[i] -=  a[i] * (b[i] ROT);
    }
}

void g_f1 (TYPE2 complex a[restrict N], TYPE complex b[restrict N], TYPE complex c[restrict N])
{
  for (int i=0; i < N; i++)
    {
      c[i] -=  conjf (a[i]) * (b[i]);
    }
}

void g_s1 (TYPE2 complex a[restrict N], TYPE complex b[restrict N], TYPE complex c[restrict N])
{
  for (int i=0; i < N; i++)
    {
      c[i] -=  a[i] * conjf (b[i] ROT);
    }
}

void caxpy_sub(double complex * restrict y, double complex * restrict x, size_t N, double complex f) {
  for (size_t i = 0; i < N; ++i)
    y[i] -= x[i]* f;
}

gcc/ChangeLog:

* internal-fn.def (COMPLEX_FMS, COMPLEX_FMS_CONJ): New.
* optabs.def (cmls_optab, cmls_conj_optab): New.
* doc/md.texi: Document them.
* tree-vect-slp-patterns.c (class complex_fms_pattern,
complex_fms_pattern::matches, complex_fms_pattern::recognize,
complex_fms_pattern::build): New.

gcc/doc/md.texi
gcc/internal-fn.def
gcc/optabs.def
gcc/tree-vect-slp-patterns.c

index 49a1ce045b19a2a1d49a1fedbacd81876e317c04..e3686dbfe6175505906bca11c3c57b85ad474fa9 100644 (file)
@@ -6247,6 +6247,51 @@ The operation is only supported for vector modes @var{m}.
 
 This pattern is not allowed to @code{FAIL}.
 
+@cindex @code{cmls@var{m}4} instruction pattern
+@item @samp{cmls@var{m}4}
+Perform a vector multiply and subtract that is semantically the same as
+a multiply and subtract of complex numbers.
+
+@smallexample
+  complex TYPE c[N];
+  complex TYPE a[N];
+  complex TYPE b[N];
+  for (int i = 0; i < N; i += 1)
+    @{
+      c[i] -= a[i] * b[i];
+    @}
+@end smallexample
+
+In GCC lane ordering the real part of the number must be in the even lanes with
+the imaginary part in the odd lanes.
+
+The operation is only supported for vector modes @var{m}.
+
+This pattern is not allowed to @code{FAIL}.
+
+@cindex @code{cmls_conj@var{m}4} instruction pattern
+@item @samp{cmls_conj@var{m}4}
+Perform a vector multiply by conjugate and subtract that is semantically
+the same as a multiply and subtract of complex numbers where the second
+multiply arguments is conjugated.
+
+@smallexample
+  complex TYPE c[N];
+  complex TYPE a[N];
+  complex TYPE b[N];
+  for (int i = 0; i < N; i += 1)
+    @{
+      c[i] -= a[i] * conj (b[i]);
+    @}
+@end smallexample
+
+In GCC lane ordering the real part of the number must be in the even lanes with
+the imaginary part in the odd lanes.
+
+The operation is only supported for vector modes @var{m}.
+
+This pattern is not allowed to @code{FAIL}.
+
 @cindex @code{cmul@var{m}4} instruction pattern
 @item @samp{cmul@var{m}4}
 Perform a vector multiply that is semantically the same as multiply of
index 020b586bc656b2691c6151d727f58601d4eec80f..daeace7a34ee950fdf724f87601478f3ec5378f7 100644 (file)
@@ -290,6 +290,8 @@ DEF_INTERNAL_FLT_FN (LDEXP, ECF_CONST, ldexp, binary)
 DEF_INTERNAL_FLT_FLOATN_FN (FMA, ECF_CONST, fma, ternary)
 DEF_INTERNAL_OPTAB_FN (COMPLEX_FMA, ECF_CONST, cmla, ternary)
 DEF_INTERNAL_OPTAB_FN (COMPLEX_FMA_CONJ, ECF_CONST, cmla_conj, ternary)
+DEF_INTERNAL_OPTAB_FN (COMPLEX_FMS, ECF_CONST, cmls, ternary)
+DEF_INTERNAL_OPTAB_FN (COMPLEX_FMS_CONJ, ECF_CONST, cmls_conj, ternary)
 
 /* Unary integer ops.  */
 DEF_INTERNAL_INT_FN (CLRSB, ECF_CONST | ECF_NOTHROW, clrsb, unary)
index cecd1b61a1f1f7ae58e182eb789e1a3551a429ba..b192a9d070b8aa72e5676b2eaa020b5bdd7ffcc8 100644 (file)
@@ -296,6 +296,8 @@ OPTAB_D (cmul_optab, "cmul$a3")
 OPTAB_D (cmul_conj_optab, "cmul_conj$a3")
 OPTAB_D (cmla_optab, "cmla$a4")
 OPTAB_D (cmla_conj_optab, "cmla_conj$a4")
+OPTAB_D (cmls_optab, "cmls$a4")
+OPTAB_D (cmls_conj_optab, "cmls_conj$a4")
 OPTAB_D (cos_optab, "cos$a2")
 OPTAB_D (cosh_optab, "cosh$a2")
 OPTAB_D (exp10_optab, "exp10$a2")
index bd632e01fb851761ad9e43ffaa1aeaf152176c47..8065a58065f742a6fe3a76f5fb6b6965b716dbf0 100644 (file)
@@ -830,7 +830,7 @@ vect_validate_multiplication (slp_tree_to_load_perm_map_t *perm_cache,
         variants to be sure.  This needs to be fixed in the mid-end so
         this part can be simpler.  */
       kind = linear_loads_p (perm_cache, right_op[0]).first;
-      if (!((kind == PERM_ODDODD
+      if (!((is_eq_or_top (linear_loads_p (perm_cache, right_op[0]), PERM_ODDODD)
           && is_eq_or_top (linear_loads_p (perm_cache, right_op[1]),
                             PERM_ODDEVEN))
          || (kind == PERM_ODDEVEN
@@ -863,6 +863,7 @@ vect_validate_multiplication (slp_tree_to_load_perm_map_t *perm_cache,
     {
       if ((kind = linear_loads_p (perm_cache, left_op[index2]).first) == PERM_EVENODD)
        return false;
+      return true;
     }
   else if (!neg_first)
     *conj_first_operand = true;
@@ -1265,6 +1266,185 @@ complex_fma_pattern::build (vec_info *vinfo)
   complex_pattern::build (vinfo);
 }
 
+/*******************************************************************************
+ * complex_fms_pattern class
+ ******************************************************************************/
+
+class complex_fms_pattern : public complex_pattern
+{
+  protected:
+    complex_fms_pattern (slp_tree *node, vec<slp_tree> *m_ops, internal_fn ifn)
+      : complex_pattern (node, m_ops, ifn)
+    {
+      this->m_num_args = 3;
+    }
+
+  public:
+    void build (vec_info *);
+    static internal_fn
+    matches (complex_operation_t op, slp_tree_to_load_perm_map_t *, slp_tree *,
+            vec<slp_tree> *);
+
+    static vect_pattern*
+    recognize (slp_tree_to_load_perm_map_t *, slp_tree *);
+
+    static vect_pattern*
+    mkInstance (slp_tree *node, vec<slp_tree> *m_ops, internal_fn ifn)
+    {
+      return new complex_fms_pattern (node, m_ops, ifn);
+    }
+};
+
+
+/* Pattern matcher for trying to match complex multiply and accumulate
+   and multiply and subtract patterns in SLP tree.
+   If the operation matches then IFN is set to the operation it matched and
+   the arguments to the two replacement statements are put in m_ops.
+
+   If no match is found then IFN is set to IFN_LAST and m_ops is unchanged.
+
+   This function matches the patterns shaped as:
+
+   double ax = (b[i+1] * a[i]) + (b[i] * a[i]);
+   double bx = (a[i+1] * b[i]) - (a[i+1] * b[i+1]);
+
+   c[i] = c[i] - ax;
+   c[i+1] = c[i+1] + bx;
+
+   If a match occurred then TRUE is returned, else FALSE.  The initial match is
+   expected to be in OP1 and the initial match operands in args0.  */
+
+internal_fn
+complex_fms_pattern::matches (complex_operation_t op,
+                             slp_tree_to_load_perm_map_t *perm_cache,
+                             slp_tree * ref_node, vec<slp_tree> *ops)
+{
+  internal_fn ifn = IFN_LAST;
+
+  /* Find the two components.  We match Complex MUL first which reduces the
+     amount of work this pattern has to do.  After that we just match the
+     head node and we're done.:
+
+     * FMS: - +.  */
+  slp_tree child = NULL;
+
+  /* We need to ignore the two_operands nodes that may also match,
+     for that we can check if they have any scalar statements and also
+     check that it's not a permute node as we're looking for a normal
+     PLUS_EXPR operation.  */
+  if (op != PLUS_MINUS)
+    return IFN_LAST;
+
+  child = SLP_TREE_CHILDREN ((*ops)[1])[1];
+  if (vect_detect_pair_op (child) != MINUS_PLUS)
+    return IFN_LAST;
+
+  /* First two nodes must be a multiply.  */
+  auto_vec<slp_tree> muls;
+  if (vect_match_call_complex_mla (child, 0) != MULT_MULT
+      || vect_match_call_complex_mla (child, 1, &muls) != MULT_MULT)
+    return IFN_LAST;
+
+  /* Now operand2+4 may lead to another expression.  */
+  auto_vec<slp_tree> left_op, right_op;
+  left_op.safe_splice (SLP_TREE_CHILDREN (muls[0]));
+  right_op.safe_splice (SLP_TREE_CHILDREN (muls[1]));
+
+  bool is_neg = vect_normalize_conj_loc (left_op);
+
+  child = SLP_TREE_CHILDREN ((*ops)[1])[0];
+  bool conj_first_operand = false;
+  if (!vect_validate_multiplication (perm_cache, right_op, left_op, false,
+                                    &conj_first_operand, true))
+    return IFN_LAST;
+
+  if (!is_neg)
+    ifn = IFN_COMPLEX_FMS;
+  else if (is_neg)
+    ifn = IFN_COMPLEX_FMS_CONJ;
+
+  if (!vect_pattern_validate_optab (ifn, *ref_node))
+    return IFN_LAST;
+
+  ops->truncate (0);
+  ops->create (4);
+
+  complex_perm_kinds_t kind = linear_loads_p (perm_cache, right_op[0]).first;
+  if (kind == PERM_EVENODD)
+    {
+      ops->quick_push (child);
+      ops->quick_push (right_op[0]);
+      ops->quick_push (right_op[1]);
+      ops->quick_push (left_op[1]);
+    }
+  else if (kind == PERM_TOP)
+    {
+      ops->quick_push (child);
+      ops->quick_push (right_op[1]);
+      ops->quick_push (right_op[0]);
+      ops->quick_push (left_op[0]);
+    }
+  else if (kind == PERM_EVENEVEN && !is_neg)
+    {
+      ops->quick_push (child);
+      ops->quick_push (right_op[1]);
+      ops->quick_push (right_op[0]);
+      ops->quick_push (left_op[0]);
+    }
+  else
+    {
+      ops->quick_push (child);
+      ops->quick_push (right_op[1]);
+      ops->quick_push (right_op[0]);
+      ops->quick_push (left_op[1]);
+    }
+
+  return ifn;
+}
+
+/* Attempt to recognize a complex mul pattern.  */
+
+vect_pattern*
+complex_fms_pattern::recognize (slp_tree_to_load_perm_map_t *perm_cache,
+                               slp_tree *node)
+{
+  auto_vec<slp_tree> ops;
+  complex_operation_t op
+    = vect_detect_pair_op (*node, true, &ops);
+  internal_fn ifn
+    = complex_fms_pattern::matches (op, perm_cache, node, &ops);
+  if (ifn == IFN_LAST)
+    return NULL;
+
+  return new complex_fms_pattern (node, &ops, ifn);
+}
+
+/* Perform a replacement of the detected complex mul pattern with the new
+   instruction sequences.  */
+
+void
+complex_fms_pattern::build (vec_info *vinfo)
+{
+  slp_tree node;
+  unsigned i;
+  FOR_EACH_VEC_ELT (SLP_TREE_CHILDREN (*this->m_node), i, node)
+    vect_free_slp_tree (node);
+
+  SLP_TREE_CHILDREN (*this->m_node).release ();
+  SLP_TREE_CHILDREN (*this->m_node).create (3);
+
+  /* First re-arrange the children.  */
+  SLP_TREE_CHILDREN (*this->m_node).quick_push (this->m_ops[0]);
+  SLP_TREE_CHILDREN (*this->m_node).quick_push (this->m_ops[1]);
+  SLP_TREE_CHILDREN (*this->m_node).quick_push (
+    vect_build_combine_node (this->m_ops[2], this->m_ops[3], *this->m_node));
+  SLP_TREE_REF_COUNT (this->m_ops[0])++;
+  SLP_TREE_REF_COUNT (this->m_ops[1])++;
+
+  /* And then rewrite the node itself.  */
+  complex_pattern::build (vinfo);
+}
+
 /*******************************************************************************
  * Pattern matching definitions
  ******************************************************************************/