slp: support complex multiply and complex multiply conjugate
authorTamar Christina <tamar.christina@arm.com>
Thu, 14 Jan 2021 20:57:17 +0000 (20:57 +0000)
committerTamar Christina <tamar.christina@arm.com>
Thu, 14 Jan 2021 20:57:17 +0000 (20:57 +0000)
This adds support for complex multiply and complex multiply and accumulate to
the vect pattern detector.

Example of instructions matched:

#include <stdio.h>
#include <complex.h>

#define N 200
#define ROT
#define TYPE float
#define TYPE2 float

void g (TYPE2 complex a[restrict N], TYPE complex b[restrict N], TYPE complex c[restrict N])
{
  for (int i=0; i < N; i++)
    {
      c[i] =  a[i] * (b[i] ROT);
    }
}

void g_f1 (TYPE2 complex a[restrict N], TYPE complex b[restrict N], TYPE complex c[restrict N])
{
  for (int i=0; i < N; i++)
    {
      c[i] =  conjf (a[i]) * (b[i] ROT);
    }
}

void g_s1 (TYPE2 complex a[restrict N], TYPE complex b[restrict N], TYPE complex c[restrict N])
{
  for (int i=0; i < N; i++)
    {
      c[i] =  a[i] * conjf (b[i] ROT);
    }
}

gcc/ChangeLog:

* internal-fn.def (COMPLEX_MUL, COMPLEX_MUL_CONJ): New.
* optabs.def (cmul_optab, cmul_conj_optab): New.
* doc/md.texi: Document them.
* tree-vect-slp-patterns.c (vect_match_call_complex_mla,
vect_normalize_conj_loc, is_eq_or_top, vect_validate_multiplication,
vect_build_combine_node, class complex_mul_pattern,
complex_mul_pattern::matches, complex_mul_pattern::recognize,
complex_mul_pattern::build): New.

gcc/doc/md.texi
gcc/internal-fn.def
gcc/optabs.def
gcc/tree-vect-slp-patterns.c

index a4435df8a1312535828ec2bcfcfcb3890786c56f..60e8c94810a3f27a4fa8d59367b0710323504e9c 100644 (file)
@@ -6202,6 +6202,50 @@ The operation is only supported for vector modes @var{m}.
 
 This pattern is not allowed to @code{FAIL}.
 
+@cindex @code{cmul@var{m}4} instruction pattern
+@item @samp{cmul@var{m}4}
+Perform a vector multiply that is semantically the same as multiply of
+complex numbers.
+
+@smallexample
+  complex TYPE c[N];
+  complex TYPE a[N];
+  complex TYPE b[N];
+  for (int i = 0; i < N; i += 1)
+    @{
+      c[i] = a[i] * b[i];
+    @}
+@end smallexample
+
+In GCC lane ordering the real part of the number must be in the even lanes with
+the imaginary part in the odd lanes.
+
+The operation is only supported for vector modes @var{m}.
+
+This pattern is not allowed to @code{FAIL}.
+
+@cindex @code{cmul_conj@var{m}4} instruction pattern
+@item @samp{cmul_conj@var{m}4}
+Perform a vector multiply by conjugate that is semantically the same as a
+multiply of complex numbers where the second multiply arguments is conjugated.
+
+@smallexample
+  complex TYPE c[N];
+  complex TYPE a[N];
+  complex TYPE b[N];
+  for (int i = 0; i < N; i += 1)
+    @{
+      c[i] = a[i] * conj (b[i]);
+    @}
+@end smallexample
+
+In GCC lane ordering the real part of the number must be in the even lanes with
+the imaginary part in the odd lanes.
+
+The operation is only supported for vector modes @var{m}.
+
+This pattern is not allowed to @code{FAIL}.
+
 @cindex @code{ffs@var{m}2} instruction pattern
 @item @samp{ffs@var{m}2}
 Store into operand 0 one plus the index of the least significant 1-bit
index 19016ce109f30f54f545162c2947a3bdca9ef02c..e3e4fe5ebadb446408b68b7408f4b86f42695b8d 100644 (file)
@@ -279,6 +279,8 @@ DEF_INTERNAL_FLT_FLOATN_FN (FMAX, ECF_CONST, fmax, binary)
 DEF_INTERNAL_OPTAB_FN (XORSIGN, ECF_CONST, xorsign, binary)
 DEF_INTERNAL_OPTAB_FN (COMPLEX_ADD_ROT90, ECF_CONST, cadd90, binary)
 DEF_INTERNAL_OPTAB_FN (COMPLEX_ADD_ROT270, ECF_CONST, cadd270, binary)
+DEF_INTERNAL_OPTAB_FN (COMPLEX_MUL, ECF_CONST, cmul, binary)
+DEF_INTERNAL_OPTAB_FN (COMPLEX_MUL_CONJ, ECF_CONST, cmul_conj, binary)
 
 
 /* FP scales.  */
index a69584186451270171cbcaac6e330fae7bd0dd15..fcc27d00dbadb8dd0f6793c12d45e8c5a5ab509e 100644 (file)
@@ -292,6 +292,8 @@ OPTAB_D (copysign_optab, "copysign$F$a3")
 OPTAB_D (xorsign_optab, "xorsign$F$a3")
 OPTAB_D (cadd90_optab, "cadd90$a3")
 OPTAB_D (cadd270_optab, "cadd270$a3")
+OPTAB_D (cmul_optab, "cmul$a3")
+OPTAB_D (cmul_conj_optab, "cmul_conj$a3")
 OPTAB_D (cos_optab, "cos$a2")
 OPTAB_D (cosh_optab, "cosh$a2")
 OPTAB_D (exp10_optab, "exp10$a2")
index c4fa269baa3ceb2142443624a6cb46f8f90533d9..dc96be51dfe2f8176621183f7e8f61da0252c770 100644 (file)
@@ -717,6 +717,374 @@ complex_add_pattern::recognize (slp_tree_to_load_perm_map_t *perm_cache,
   return new complex_add_pattern (node, &ops, ifn);
 }
 
+/*******************************************************************************
+ * complex_mul_pattern
+ ******************************************************************************/
+
+/* Helper function of that looks for a match in the CHILDth child of NODE.  The
+   child used is stored in RES.
+
+   If the match is successful then ARGS will contain the operands matched
+   and the complex_operation_t type is returned.  If match is not successful
+   then CMPLX_NONE is returned and ARGS is left unmodified.  */
+
+static inline complex_operation_t
+vect_match_call_complex_mla (slp_tree node, unsigned child,
+                            vec<slp_tree> *args = NULL, slp_tree *res = NULL)
+{
+  gcc_assert (child < SLP_TREE_CHILDREN (node).length ());
+
+  slp_tree data = SLP_TREE_CHILDREN (node)[child];
+
+  if (res)
+    *res = data;
+
+  return vect_detect_pair_op (data, false, args);
+}
+
+/* Check to see if either of the trees in ARGS are a NEGATE_EXPR.  If the first
+   child (args[0]) is a NEGATE_EXPR then NEG_FIRST_P is set to TRUE.
+
+   If a negate is found then the values in ARGS are reordered such that the
+   negate node is always the second one and the entry is replaced by the child
+   of the negate node.  */
+
+static inline bool
+vect_normalize_conj_loc (vec<slp_tree> args, bool *neg_first_p = NULL)
+{
+  gcc_assert (args.length () == 2);
+  bool neg_found = false;
+
+  if (vect_match_expression_p (args[0], NEGATE_EXPR))
+    {
+      std::swap (args[0], args[1]);
+      neg_found = true;
+      if (neg_first_p)
+       *neg_first_p = true;
+    }
+  else if (vect_match_expression_p (args[1], NEGATE_EXPR))
+    {
+      neg_found = true;
+      if (neg_first_p)
+       *neg_first_p = false;
+    }
+
+  if (neg_found)
+    args[1] = SLP_TREE_CHILDREN (args[1])[0];
+
+  return neg_found;
+}
+
+/* Helper function to check if PERM is KIND or PERM_TOP.  */
+
+static inline bool
+is_eq_or_top (complex_load_perm_t perm, complex_perm_kinds_t kind)
+{
+  return perm.first == kind || perm.first == PERM_TOP;
+}
+
+/* Helper function that checks to see if LEFT_OP and RIGHT_OP are both MULT_EXPR
+   nodes but also that they represent an operation that is either a complex
+   multiplication or a complex multiplication by conjugated value.
+
+   Of the negation is expected to be in the first half of the tree (As required
+   by an FMS pattern) then NEG_FIRST is true.  If the operation is a conjugate
+   operation then CONJ_FIRST_OPERAND is set to indicate whether the first or
+   second operand contains the conjugate operation.  */
+
+static inline bool
+vect_validate_multiplication (slp_tree_to_load_perm_map_t *perm_cache,
+                            vec<slp_tree> left_op, vec<slp_tree> right_op,
+                            bool neg_first, bool *conj_first_operand,
+                            bool fms)
+{
+  /* The presence of a negation indicates that we have either a conjugate or a
+     rotation.  We need to distinguish which one.  */
+  *conj_first_operand = false;
+  complex_perm_kinds_t kind;
+
+  /* Complex conjugates have the negation on the imaginary part of the
+     number where rotations affect the real component.  So check if the
+     negation is on a dup of lane 1.  */
+  if (fms)
+    {
+      /* Canonicalization for fms is not consistent. So have to test both
+        variants to be sure.  This needs to be fixed in the mid-end so
+        this part can be simpler.  */
+      kind = linear_loads_p (perm_cache, right_op[0]).first;
+      if (!((kind == PERM_ODDODD
+          && is_eq_or_top (linear_loads_p (perm_cache, right_op[1]),
+                            PERM_ODDEVEN))
+         || (kind == PERM_ODDEVEN
+             && is_eq_or_top (linear_loads_p (perm_cache, right_op[1]),
+                            PERM_ODDODD))))
+       return false;
+    }
+  else
+    {
+      if (linear_loads_p (perm_cache, right_op[1]).first != PERM_ODDODD
+         && !is_eq_or_top (linear_loads_p (perm_cache, right_op[0]),
+                           PERM_ODDEVEN))
+       return false;
+    }
+
+  /* Deal with differences in indexes.  */
+  int index1 = fms ? 1 : 0;
+  int index2 = fms ? 0 : 1;
+
+  /* Check if the conjugate is on the second first or second operand.  The
+     order of the node with the conjugate value determines this, and the dup
+     node must be one of lane 0 of the same DR as the neg node.  */
+  kind = linear_loads_p (perm_cache, left_op[index1]).first;
+  if (kind == PERM_TOP)
+    {
+      if (linear_loads_p (perm_cache, left_op[index2]).first == PERM_EVENODD)
+       return true;
+    }
+  else if (kind == PERM_EVENODD)
+    {
+      if ((kind = linear_loads_p (perm_cache, left_op[index2]).first) == PERM_EVENODD)
+       return false;
+    }
+  else if (!neg_first)
+    *conj_first_operand = true;
+  else
+    return false;
+
+  if (kind != PERM_EVENEVEN)
+    return false;
+
+  return true;
+}
+
+/* Helper function to help distinguish between a conjugate and a rotation in a
+   complex multiplication.  The operations have similar shapes but the order of
+   the load permutes are different.  This function returns TRUE when the order
+   is consistent with a multiplication or multiplication by conjugated
+   operand but returns FALSE if it's a multiplication by rotated operand.  */
+
+static inline bool
+vect_validate_multiplication (slp_tree_to_load_perm_map_t *perm_cache,
+                            vec<slp_tree> op, complex_perm_kinds_t permKind)
+{
+  /* The left node is the more common case, test it first.  */
+  if (!is_eq_or_top (linear_loads_p (perm_cache, op[0]), permKind))
+    {
+      if (!is_eq_or_top (linear_loads_p (perm_cache, op[1]), permKind))
+       return false;
+    }
+  return true;
+}
+
+/* This function combines two nodes containing only even and only odd lanes
+   together into a single node which contains the nodes in even/odd order
+   by using a lane permute.
+
+   The lanes in EVEN and ODD are duplicated 2 times inside the vectors.
+   So for a lanes = 4 EVEN contains {EVEN1, EVEN1, EVEN2, EVEN2}.
+
+   The tree REPRESENTATION is taken from the supplied REP along with the
+   vectype which must be the same between all three nodes.
+*/
+
+static slp_tree
+vect_build_combine_node (slp_tree even, slp_tree odd, slp_tree rep)
+{
+  vec<std::pair<unsigned, unsigned> > perm;
+  perm.create (SLP_TREE_LANES (rep));
+
+  for (unsigned x = 0; x < SLP_TREE_LANES (rep); x+=2)
+    {
+      perm.quick_push (std::make_pair (0, x));
+      perm.quick_push (std::make_pair (1, x+1));
+    }
+
+  slp_tree vnode = vect_create_new_slp_node (2, SLP_TREE_CODE (even));
+  SLP_TREE_CODE (vnode) = VEC_PERM_EXPR;
+  SLP_TREE_LANE_PERMUTATION (vnode) = perm;
+
+  SLP_TREE_CHILDREN (vnode).create (2);
+  SLP_TREE_CHILDREN (vnode).quick_push (even);
+  SLP_TREE_CHILDREN (vnode).quick_push (odd);
+  SLP_TREE_REF_COUNT (even)++;
+  SLP_TREE_REF_COUNT (odd)++;
+  SLP_TREE_REF_COUNT (vnode) = 1;
+
+  SLP_TREE_LANES (vnode) = SLP_TREE_LANES (rep);
+  gcc_assert (perm.length () == SLP_TREE_LANES (vnode));
+  /* Representation is set to that of the current node as the vectorizer
+     can't deal with VEC_PERMs with no representation, as would be the
+     case with invariants.  */
+  SLP_TREE_REPRESENTATIVE (vnode) = SLP_TREE_REPRESENTATIVE (rep);
+  SLP_TREE_VECTYPE (vnode) = SLP_TREE_VECTYPE (rep);
+  return vnode;
+}
+
+class complex_mul_pattern : public complex_pattern
+{
+  protected:
+    complex_mul_pattern (slp_tree *node, vec<slp_tree> *m_ops, internal_fn ifn)
+      : complex_pattern (node, m_ops, ifn)
+    {
+      this->m_num_args = 2;
+    }
+
+  public:
+    void build (vec_info *);
+    static internal_fn
+    matches (complex_operation_t op, slp_tree_to_load_perm_map_t *, slp_tree *,
+            vec<slp_tree> *);
+
+    static vect_pattern*
+    recognize (slp_tree_to_load_perm_map_t *, slp_tree *);
+
+    static vect_pattern*
+    mkInstance (slp_tree *node, vec<slp_tree> *m_ops, internal_fn ifn)
+    {
+      return new complex_mul_pattern (node, m_ops, ifn);
+    }
+
+};
+
+/* Pattern matcher for trying to match complex multiply pattern in SLP tree
+   If the operation matches then IFN is set to the operation it matched
+   and the arguments to the two replacement statements are put in m_ops.
+
+   If no match is found then IFN is set to IFN_LAST and m_ops is unchanged.
+
+   This function matches the patterns shaped as:
+
+   double ax = (b[i+1] * a[i]);
+   double bx = (a[i+1] * b[i]);
+
+   c[i] = c[i] - ax;
+   c[i+1] = c[i+1] + bx;
+
+   If a match occurred then TRUE is returned, else FALSE.  The initial match is
+   expected to be in OP1 and the initial match operands in args0.  */
+
+internal_fn
+complex_mul_pattern::matches (complex_operation_t op,
+                             slp_tree_to_load_perm_map_t *perm_cache,
+                             slp_tree *node, vec<slp_tree> *ops)
+{
+  internal_fn ifn = IFN_LAST;
+
+  if (op != MINUS_PLUS)
+    return IFN_LAST;
+
+  slp_tree root = *node;
+  /* First two nodes must be a multiply.  */
+  auto_vec<slp_tree> muls;
+  if (vect_match_call_complex_mla (root, 0) != MULT_MULT
+      || vect_match_call_complex_mla (root, 1, &muls) != MULT_MULT)
+    return IFN_LAST;
+
+  /* Now operand2+4 may lead to another expression.  */
+  auto_vec<slp_tree> left_op, right_op;
+  left_op.safe_splice (SLP_TREE_CHILDREN (muls[0]));
+  right_op.safe_splice (SLP_TREE_CHILDREN (muls[1]));
+
+  if (linear_loads_p (perm_cache, left_op[1]).first == PERM_ODDEVEN)
+    return IFN_LAST;
+
+  bool neg_first = false;
+  bool conj_first_operand = false;
+  bool is_neg = vect_normalize_conj_loc (right_op, &neg_first);
+
+  if (!is_neg)
+    {
+      /* A multiplication needs to multiply agains the real pair, otherwise
+        the pattern matches that of FMS.   */
+      if (!vect_validate_multiplication (perm_cache, left_op, PERM_EVENEVEN)
+         || vect_normalize_conj_loc (left_op))
+       return IFN_LAST;
+      ifn = IFN_COMPLEX_MUL;
+    }
+  else if (is_neg)
+    {
+      if (!vect_validate_multiplication (perm_cache, left_op, right_op,
+                                        neg_first, &conj_first_operand,
+                                        false))
+       return IFN_LAST;
+
+      ifn = IFN_COMPLEX_MUL_CONJ;
+    }
+
+  if (!vect_pattern_validate_optab (ifn, *node))
+    return IFN_LAST;
+
+  ops->truncate (0);
+  ops->create (3);
+
+  complex_perm_kinds_t kind = linear_loads_p (perm_cache, left_op[0]).first;
+  if (kind == PERM_EVENODD)
+    {
+      ops->quick_push (left_op[1]);
+      ops->quick_push (right_op[1]);
+      ops->quick_push (left_op[0]);
+    }
+  else if (kind == PERM_TOP)
+    {
+      ops->quick_push (left_op[1]);
+      ops->quick_push (right_op[1]);
+      ops->quick_push (left_op[0]);
+    }
+  else if (kind == PERM_EVENEVEN && !conj_first_operand)
+    {
+      ops->quick_push (left_op[0]);
+      ops->quick_push (right_op[0]);
+      ops->quick_push (left_op[1]);
+    }
+  else
+    {
+      ops->quick_push (left_op[0]);
+      ops->quick_push (right_op[1]);
+      ops->quick_push (left_op[1]);
+    }
+
+  return ifn;
+}
+
+/* Attempt to recognize a complex mul pattern.  */
+
+vect_pattern*
+complex_mul_pattern::recognize (slp_tree_to_load_perm_map_t *perm_cache,
+                               slp_tree *node)
+{
+  auto_vec<slp_tree> ops;
+  complex_operation_t op
+    = vect_detect_pair_op (*node, true, &ops);
+  internal_fn ifn
+    = complex_mul_pattern::matches (op, perm_cache, node, &ops);
+  if (ifn == IFN_LAST)
+    return NULL;
+
+  return new complex_mul_pattern (node, &ops, ifn);
+}
+
+/* Perform a replacement of the detected complex mul pattern with the new
+   instruction sequences.  */
+
+void
+complex_mul_pattern::build (vec_info *vinfo)
+{
+  slp_tree node;
+  unsigned i;
+  FOR_EACH_VEC_ELT (SLP_TREE_CHILDREN (*this->m_node), i, node)
+    vect_free_slp_tree (node);
+
+  /* First re-arrange the children.  */
+  SLP_TREE_CHILDREN (*this->m_node).reserve_exact (2);
+  SLP_TREE_CHILDREN (*this->m_node)[0] = this->m_ops[2];
+  SLP_TREE_CHILDREN (*this->m_node)[1] =
+    vect_build_combine_node (this->m_ops[0], this->m_ops[1], *this->m_node);
+  SLP_TREE_REF_COUNT (this->m_ops[2])++;
+
+  /* And then rewrite the node itself.  */
+  complex_pattern::build (vinfo);
+}
+
 /*******************************************************************************
  * Pattern matching definitions
  ******************************************************************************/