From 478e571a3eedfab198e48e8d2c8f02e491ba2c28 Mon Sep 17 00:00:00 2001 From: Tamar Christina Date: Thu, 14 Jan 2021 20:59:12 +0000 Subject: [PATCH] slp: support complex FMS and complex FMS conjugate This adds support for FMS and FMS conjugated to the slp pattern matcher. Example of matches: #include <stdio.h> #include <complex.h> #define N 200 #define ROT #define TYPE float #define TYPE2 float void g (TYPE2 complex a[restrict N], TYPE complex b[restrict N], TYPE complex c[restrict N]) { for (int i=0; i < N; i++) { c[i] -= a[i] * (b[i] ROT); } } void g_f1 (TYPE2 complex a[restrict N], TYPE complex b[restrict N], TYPE complex c[restrict N]) { for (int i=0; i < N; i++) { c[i] -= conjf (a[i]) * (b[i]); } } void g_s1 (TYPE2 complex a[restrict N], TYPE complex b[restrict N], TYPE complex c[restrict N]) { for (int i=0; i < N; i++) { c[i] -= a[i] * conjf (b[i] ROT); } } void caxpy_sub(double complex * restrict y, double complex * restrict x, size_t N, double complex f) { for (size_t i = 0; i < N; ++i) y[i] -= x[i]* f; } gcc/ChangeLog: * internal-fn.def (COMPLEX_FMS, COMPLEX_FMS_CONJ): New. * optabs.def (cmls_optab, cmls_conj_optab): New. * doc/md.texi: Document them. * tree-vect-slp-patterns.c (class complex_fms_pattern, complex_fms_pattern::matches, complex_fms_pattern::recognize, complex_fms_pattern::build): New. --- gcc/doc/md.texi | 45 +++++++++ gcc/internal-fn.def | 2 + gcc/optabs.def | 2 + gcc/tree-vect-slp-patterns.c | 182 ++++++++++++++++++++++++++++++++++- 4 files changed, 230 insertions(+), 1 deletion(-) diff --git a/gcc/doc/md.texi b/gcc/doc/md.texi index 49a1ce045b1..e3686dbfe61 100644 --- a/gcc/doc/md.texi +++ b/gcc/doc/md.texi @@ -6247,6 +6247,51 @@ The operation is only supported for vector modes @var{m}. This pattern is not allowed to @code{FAIL}. +@cindex @code{cmls@var{m}4} instruction pattern +@item @samp{cmls@var{m}4} +Perform a vector multiply and subtract that is semantically the same as +a multiply and subtract of complex numbers. 
+ +@smallexample + complex TYPE c[N]; + complex TYPE a[N]; + complex TYPE b[N]; + for (int i = 0; i < N; i += 1) + @{ + c[i] -= a[i] * b[i]; + @} +@end smallexample + +In GCC lane ordering the real part of the number must be in the even lanes with +the imaginary part in the odd lanes. + +The operation is only supported for vector modes @var{m}. + +This pattern is not allowed to @code{FAIL}. + +@cindex @code{cmls_conj@var{m}4} instruction pattern +@item @samp{cmls_conj@var{m}4} +Perform a vector multiply by conjugate and subtract that is semantically +the same as a multiply and subtract of complex numbers where the second +multiply argument is conjugated. + +@smallexample + complex TYPE c[N]; + complex TYPE a[N]; + complex TYPE b[N]; + for (int i = 0; i < N; i += 1) + @{ + c[i] -= a[i] * conj (b[i]); + @} +@end smallexample + +In GCC lane ordering the real part of the number must be in the even lanes with +the imaginary part in the odd lanes. + +The operation is only supported for vector modes @var{m}. + +This pattern is not allowed to @code{FAIL}. + @cindex @code{cmul@var{m}4} instruction pattern @item @samp{cmul@var{m}4} Perform a vector multiply that is semantically the same as multiply of diff --git a/gcc/internal-fn.def b/gcc/internal-fn.def index 020b586bc65..daeace7a34e 100644 --- a/gcc/internal-fn.def +++ b/gcc/internal-fn.def @@ -290,6 +290,8 @@ DEF_INTERNAL_FLT_FN (LDEXP, ECF_CONST, ldexp, binary) DEF_INTERNAL_FLT_FLOATN_FN (FMA, ECF_CONST, fma, ternary) DEF_INTERNAL_OPTAB_FN (COMPLEX_FMA, ECF_CONST, cmla, ternary) DEF_INTERNAL_OPTAB_FN (COMPLEX_FMA_CONJ, ECF_CONST, cmla_conj, ternary) +DEF_INTERNAL_OPTAB_FN (COMPLEX_FMS, ECF_CONST, cmls, ternary) +DEF_INTERNAL_OPTAB_FN (COMPLEX_FMS_CONJ, ECF_CONST, cmls_conj, ternary) /* Unary integer ops. 
*/ DEF_INTERNAL_INT_FN (CLRSB, ECF_CONST | ECF_NOTHROW, clrsb, unary) diff --git a/gcc/optabs.def b/gcc/optabs.def index cecd1b61a1f..b192a9d070b 100644 --- a/gcc/optabs.def +++ b/gcc/optabs.def @@ -296,6 +296,8 @@ OPTAB_D (cmul_optab, "cmul$a3") OPTAB_D (cmul_conj_optab, "cmul_conj$a3") OPTAB_D (cmla_optab, "cmla$a4") OPTAB_D (cmla_conj_optab, "cmla_conj$a4") +OPTAB_D (cmls_optab, "cmls$a4") +OPTAB_D (cmls_conj_optab, "cmls_conj$a4") OPTAB_D (cos_optab, "cos$a2") OPTAB_D (cosh_optab, "cosh$a2") OPTAB_D (exp10_optab, "exp10$a2") diff --git a/gcc/tree-vect-slp-patterns.c b/gcc/tree-vect-slp-patterns.c index bd632e01fb8..8065a58065f 100644 --- a/gcc/tree-vect-slp-patterns.c +++ b/gcc/tree-vect-slp-patterns.c @@ -830,7 +830,7 @@ vect_validate_multiplication (slp_tree_to_load_perm_map_t *perm_cache, variants to be sure. This needs to be fixed in the mid-end so this part can be simpler. */ kind = linear_loads_p (perm_cache, right_op[0]).first; - if (!((kind == PERM_ODDODD + if (!((is_eq_or_top (linear_loads_p (perm_cache, right_op[0]), PERM_ODDODD) && is_eq_or_top (linear_loads_p (perm_cache, right_op[1]), PERM_ODDEVEN)) || (kind == PERM_ODDEVEN @@ -863,6 +863,7 @@ vect_validate_multiplication (slp_tree_to_load_perm_map_t *perm_cache, { if ((kind = linear_loads_p (perm_cache, left_op[index2]).first) == PERM_EVENODD) return false; + return true; } else if (!neg_first) *conj_first_operand = true; @@ -1265,6 +1266,185 @@ complex_fma_pattern::build (vec_info *vinfo) complex_pattern::build (vinfo); } +/******************************************************************************* + * complex_fms_pattern class + ******************************************************************************/ + +class complex_fms_pattern : public complex_pattern +{ + protected: + complex_fms_pattern (slp_tree *node, vec *m_ops, internal_fn ifn) + : complex_pattern (node, m_ops, ifn) + { + this->m_num_args = 3; + } + + public: + void build (vec_info *); + static internal_fn + matches 
(complex_operation_t op, slp_tree_to_load_perm_map_t *, slp_tree *, + vec *); + + static vect_pattern* + recognize (slp_tree_to_load_perm_map_t *, slp_tree *); + + static vect_pattern* + mkInstance (slp_tree *node, vec *m_ops, internal_fn ifn) + { + return new complex_fms_pattern (node, m_ops, ifn); + } +}; + + +/* Pattern matcher for trying to match complex multiply and accumulate + and multiply and subtract patterns in SLP tree. + If the operation matches then the IFN it matched is returned and + the arguments to the two replacement statements are put in OPS. + + If no match is found then IFN_LAST is returned and OPS is unchanged. + + This function matches the patterns shaped as: + + double ax = (b[i+1] * a[i]) + (b[i] * a[i]); + double bx = (a[i+1] * b[i]) - (a[i+1] * b[i+1]); + + c[i] = c[i] - ax; + c[i+1] = c[i+1] + bx; + + If a match occurred then its IFN is returned, else IFN_LAST. The initial + match is expected to be in OP and the initial match operands in OPS. */ + +internal_fn +complex_fms_pattern::matches (complex_operation_t op, + slp_tree_to_load_perm_map_t *perm_cache, + slp_tree * ref_node, vec *ops) +{ + internal_fn ifn = IFN_LAST; + + /* Find the two components. We match Complex MUL first which reduces the + amount of work this pattern has to do. After that we just match the + head node and we're done: + + * FMS: - +. */ + slp_tree child = NULL; + + /* We need to ignore the two_operands nodes that may also match, + for that we can check if they have any scalar statements and also + check that it's not a permute node as we're looking for a normal + PLUS_EXPR operation. */ + if (op != PLUS_MINUS) + return IFN_LAST; + + child = SLP_TREE_CHILDREN ((*ops)[1])[1]; + if (vect_detect_pair_op (child) != MINUS_PLUS) + return IFN_LAST; + + /* First two nodes must be a multiply. 
 */ + auto_vec muls; + if (vect_match_call_complex_mla (child, 0) != MULT_MULT + || vect_match_call_complex_mla (child, 1, &muls) != MULT_MULT) + return IFN_LAST; + + /* Now operand2+4 may lead to another expression. */ + auto_vec left_op, right_op; + left_op.safe_splice (SLP_TREE_CHILDREN (muls[0])); + right_op.safe_splice (SLP_TREE_CHILDREN (muls[1])); + + bool is_neg = vect_normalize_conj_loc (left_op); + + child = SLP_TREE_CHILDREN ((*ops)[1])[0]; + bool conj_first_operand = false; + if (!vect_validate_multiplication (perm_cache, right_op, left_op, false, + &conj_first_operand, true)) + return IFN_LAST; + + if (!is_neg) + ifn = IFN_COMPLEX_FMS; + else if (is_neg) + ifn = IFN_COMPLEX_FMS_CONJ; + + if (!vect_pattern_validate_optab (ifn, *ref_node)) + return IFN_LAST; + + ops->truncate (0); + ops->create (4); + + complex_perm_kinds_t kind = linear_loads_p (perm_cache, right_op[0]).first; + if (kind == PERM_EVENODD) + { + ops->quick_push (child); + ops->quick_push (right_op[0]); + ops->quick_push (right_op[1]); + ops->quick_push (left_op[1]); + } + else if (kind == PERM_TOP) + { + ops->quick_push (child); + ops->quick_push (right_op[1]); + ops->quick_push (right_op[0]); + ops->quick_push (left_op[0]); + } + else if (kind == PERM_EVENEVEN && !is_neg) + { + ops->quick_push (child); + ops->quick_push (right_op[1]); + ops->quick_push (right_op[0]); + ops->quick_push (left_op[0]); + } + else + { + ops->quick_push (child); + ops->quick_push (right_op[1]); + ops->quick_push (right_op[0]); + ops->quick_push (left_op[1]); + } + + return ifn; +} + +/* Attempt to recognize a complex FMS pattern. 
 */ + +vect_pattern* +complex_fms_pattern::recognize (slp_tree_to_load_perm_map_t *perm_cache, + slp_tree *node) +{ + auto_vec ops; + complex_operation_t op + = vect_detect_pair_op (*node, true, &ops); + internal_fn ifn + = complex_fms_pattern::matches (op, perm_cache, node, &ops); + if (ifn == IFN_LAST) + return NULL; + + return new complex_fms_pattern (node, &ops, ifn); +} + +/* Perform a replacement of the detected complex FMS pattern with the new + instruction sequences. */ + +void +complex_fms_pattern::build (vec_info *vinfo) +{ + slp_tree node; + unsigned i; + FOR_EACH_VEC_ELT (SLP_TREE_CHILDREN (*this->m_node), i, node) + vect_free_slp_tree (node); + + SLP_TREE_CHILDREN (*this->m_node).release (); + SLP_TREE_CHILDREN (*this->m_node).create (3); + + /* First re-arrange the children. */ + SLP_TREE_CHILDREN (*this->m_node).quick_push (this->m_ops[0]); + SLP_TREE_CHILDREN (*this->m_node).quick_push (this->m_ops[1]); + SLP_TREE_CHILDREN (*this->m_node).quick_push ( + vect_build_combine_node (this->m_ops[2], this->m_ops[3], *this->m_node)); + SLP_TREE_REF_COUNT (this->m_ops[0])++; + SLP_TREE_REF_COUNT (this->m_ops[1])++; + + /* And then rewrite the node itself. */ + complex_pattern::build (vinfo); +} + /******************************************************************************* * Pattern matching definitions ******************************************************************************/ -- 2.30.2