From 31fac31800b5671d17c46108013d6fc709370ef3 Mon Sep 17 00:00:00 2001 From: Tamar Christina Date: Thu, 14 Jan 2021 20:58:12 +0000 Subject: [PATCH] slp: support complex FMA and complex FMA conjugate This adds support for FMA and FMA conjugated to the slp pattern matcher. Example of instructions matched: #include #include #define N 200 #define ROT #define TYPE float #define TYPE2 float void g (TYPE2 complex a[restrict N], TYPE complex b[restrict N], TYPE complex c[restrict N]) { for (int i=0; i < N; i++) { c[i] += a[i] * (b[i] ROT); } } void g_f1 (TYPE2 complex a[restrict N], TYPE complex b[restrict N], TYPE complex c[restrict N]) { for (int i=0; i < N; i++) { c[i] += conjf (a[i]) * (b[i] ROT); } } void g_s1 (TYPE2 complex a[restrict N], TYPE complex b[restrict N], TYPE complex c[restrict N]) { for (int i=0; i < N; i++) { c[i] += a[i] * conjf (b[i] ROT); } } void caxpy_add(double complex * restrict y, double complex * restrict x, size_t N, double complex f) { for (size_t i = 0; i < N; ++i) y[i] += x[i]* f; } gcc/ChangeLog: * internal-fn.def (COMPLEX_FMA, COMPLEX_FMA_CONJ): New. * optabs.def (cmla_optab, cmla_conj_optab): New. * doc/md.texi: Document them. * tree-vect-slp-patterns.c (vect_match_call_p, class complex_fma_pattern, vect_slp_reset_pattern, complex_fma_pattern::matches, complex_fma_pattern::recognize, complex_fma_pattern::build): New. --- gcc/doc/md.texi | 45 +++++++++ gcc/internal-fn.def | 2 + gcc/optabs.def | 2 + gcc/tree-vect-slp-patterns.c | 180 +++++++++++++++++++++++++++++++++++ 4 files changed, 229 insertions(+) diff --git a/gcc/doc/md.texi b/gcc/doc/md.texi index 60e8c94810a..49a1ce045b1 100644 --- a/gcc/doc/md.texi +++ b/gcc/doc/md.texi @@ -6202,6 +6202,51 @@ The operation is only supported for vector modes @var{m}. This pattern is not allowed to @code{FAIL}. +@cindex @code{cmla@var{m}4} instruction pattern +@item @samp{cmla@var{m}4} +Perform a vector multiply and accumulate that is semantically the same as +a multiply and accumulate of complex numbers. + +@smallexample + complex TYPE c[N]; + complex TYPE a[N]; + complex TYPE b[N]; + for (int i = 0; i < N; i += 1) + @{ + c[i] += a[i] * b[i]; + @} +@end smallexample + +In GCC lane ordering the real part of the number must be in the even lanes with +the imaginary part in the odd lanes. + +The operation is only supported for vector modes @var{m}. + +This pattern is not allowed to @code{FAIL}. + +@cindex @code{cmla_conj@var{m}4} instruction pattern +@item @samp{cmla_conj@var{m}4} +Perform a vector multiply by conjugate and accumulate that is semantically +the same as a multiply and accumulate of complex numbers where the second +multiply arguments is conjugated. + +@smallexample + complex TYPE c[N]; + complex TYPE a[N]; + complex TYPE b[N]; + for (int i = 0; i < N; i += 1) + @{ + c[i] += a[i] * conj (b[i]); + @} +@end smallexample + +In GCC lane ordering the real part of the number must be in the even lanes with +the imaginary part in the odd lanes. + +The operation is only supported for vector modes @var{m}. + +This pattern is not allowed to @code{FAIL}. + @cindex @code{cmul@var{m}4} instruction pattern @item @samp{cmul@var{m}4} Perform a vector multiply that is semantically the same as multiply of diff --git a/gcc/internal-fn.def b/gcc/internal-fn.def index e3e4fe5ebad..020b586bc65 100644 --- a/gcc/internal-fn.def +++ b/gcc/internal-fn.def @@ -288,6 +288,8 @@ DEF_INTERNAL_FLT_FN (LDEXP, ECF_CONST, ldexp, binary) /* Ternary math functions. */ DEF_INTERNAL_FLT_FLOATN_FN (FMA, ECF_CONST, fma, ternary) +DEF_INTERNAL_OPTAB_FN (COMPLEX_FMA, ECF_CONST, cmla, ternary) +DEF_INTERNAL_OPTAB_FN (COMPLEX_FMA_CONJ, ECF_CONST, cmla_conj, ternary) /* Unary integer ops. */ DEF_INTERNAL_INT_FN (CLRSB, ECF_CONST | ECF_NOTHROW, clrsb, unary) diff --git a/gcc/optabs.def b/gcc/optabs.def index fcc27d00dba..cecd1b61a1f 100644 --- a/gcc/optabs.def +++ b/gcc/optabs.def @@ -294,6 +294,8 @@ OPTAB_D (cadd90_optab, "cadd90$a3") OPTAB_D (cadd270_optab, "cadd270$a3") OPTAB_D (cmul_optab, "cmul$a3") OPTAB_D (cmul_conj_optab, "cmul_conj$a3") +OPTAB_D (cmla_optab, "cmla$a4") +OPTAB_D (cmla_conj_optab, "cmla_conj$a4") OPTAB_D (cos_optab, "cos$a2") OPTAB_D (cosh_optab, "cosh$a2") OPTAB_D (exp10_optab, "exp10$a2") diff --git a/gcc/tree-vect-slp-patterns.c b/gcc/tree-vect-slp-patterns.c index dc96be51dfe..bd632e01fb8 100644 --- a/gcc/tree-vect-slp-patterns.c +++ b/gcc/tree-vect-slp-patterns.c @@ -325,6 +325,24 @@ vect_match_expression_p (slp_tree node, tree_code code) return true; } +/* Checks to see if the expression represented by NODE is a call to the internal + function FN. */ + +static inline bool +vect_match_call_p (slp_tree node, internal_fn fn) +{ + if (!node + || !SLP_TREE_REPRESENTATIVE (node)) + return false; + + gimple* expr = STMT_VINFO_STMT (SLP_TREE_REPRESENTATIVE (node)); + if (!expr + || !gimple_call_internal_p (expr, fn)) + return false; + + return true; +} + /* Check if the given lane permute in PERMUTES matches an alternating sequence of {even odd even odd ...}. This to account for unrolled loops. Further mode there resulting permute must be linear. */ @@ -1085,6 +1103,168 @@ complex_mul_pattern::build (vec_info *vinfo) complex_pattern::build (vinfo); } +/******************************************************************************* + * complex_fma_pattern class + ******************************************************************************/ + +class complex_fma_pattern : public complex_pattern +{ + protected: + complex_fma_pattern (slp_tree *node, vec *m_ops, internal_fn ifn) + : complex_pattern (node, m_ops, ifn) + { + this->m_num_args = 3; + } + + public: + void build (vec_info *); + static internal_fn + matches (complex_operation_t op, slp_tree_to_load_perm_map_t *, slp_tree *, + vec *); + + static vect_pattern* + recognize (slp_tree_to_load_perm_map_t *, slp_tree *); + + static vect_pattern* + mkInstance (slp_tree *node, vec *m_ops, internal_fn ifn) + { + return new complex_fma_pattern (node, m_ops, ifn); + } +}; + +/* Helper function to "reset" a previously matched node and undo the changes + made enough so that the node is treated as an irrelevant node. */ + +static inline void +vect_slp_reset_pattern (slp_tree node) +{ + stmt_vec_info stmt_info = vect_orig_stmt (SLP_TREE_REPRESENTATIVE (node)); + STMT_VINFO_IN_PATTERN_P (stmt_info) = false; + STMT_SLP_TYPE (stmt_info) = pure_slp; + SLP_TREE_REPRESENTATIVE (node) = stmt_info; +} + +/* Pattern matcher for trying to match complex multiply and accumulate + and multiply and subtract patterns in SLP tree. + If the operation matches then IFN is set to the operation it matched and + the arguments to the two replacement statements are put in m_ops. + + If no match is found then IFN is set to IFN_LAST and m_ops is unchanged. + + This function matches the patterns shaped as: + + double ax = (b[i+1] * a[i]) + (b[i] * a[i]); + double bx = (a[i+1] * b[i]) - (a[i+1] * b[i+1]); + + c[i] = c[i] - ax; + c[i+1] = c[i+1] + bx; + + If a match occurred then TRUE is returned, else FALSE. The match is + performed after COMPLEX_MUL which would have done the majority of the work. + This function merely matches an ADD with a COMPLEX_MUL IFN. The initial + match is expected to be in OP1 and the initial match operands in args0. */ + +internal_fn +complex_fma_pattern::matches (complex_operation_t op, + slp_tree_to_load_perm_map_t * /* perm_cache */, + slp_tree *ref_node, vec *ops) +{ + internal_fn ifn = IFN_LAST; + + /* Find the two components. We match Complex MUL first which reduces the + amount of work this pattern has to do. After that we just match the + head node and we're done.: + + * FMA: + +. + + We need to ignore the two_operands nodes that may also match. + For that we can check if they have any scalar statements and also + check that it's not a permute node as we're looking for a normal + PLUS_EXPR operation. */ + if (op != CMPLX_NONE) + return IFN_LAST; + + /* Find the two components. We match Complex MUL first which reduces the + amount of work this pattern has to do. After that we just match the + head node and we're done.: + + * FMA: + + on a non-two_operands node. */ + slp_tree vnode = *ref_node; + if (SLP_TREE_LANE_PERMUTATION (vnode).exists () + || !SLP_TREE_CHILDREN (vnode).exists () + || !vect_match_expression_p (vnode, PLUS_EXPR)) + return IFN_LAST; + + slp_tree node = SLP_TREE_CHILDREN (vnode)[1]; + + if (vect_match_call_p (node, IFN_COMPLEX_MUL)) + ifn = IFN_COMPLEX_FMA; + else if (vect_match_call_p (node, IFN_COMPLEX_MUL_CONJ)) + ifn = IFN_COMPLEX_FMA_CONJ; + else + return IFN_LAST; + + if (!vect_pattern_validate_optab (ifn, vnode)) + return IFN_LAST; + + /* FMA matched ADD + CMUL. During the matching of CMUL the + stmt that starts the pattern is marked as being in a pattern, + namely the CMUL. When replacing this with a CFMA we have to + unmark this statement as being in a pattern. This is because + vect_mark_pattern_stmts will only mark the current stmt as being + in a pattern. Later on when the scalar stmts are examined the + old statement which is supposed to be irrelevant will point to + CMUL unless we undo the pattern relationship here. */ + vect_slp_reset_pattern (node); + ops->truncate (0); + ops->create (3); + + if (ifn == IFN_COMPLEX_FMA) + { + ops->quick_push (SLP_TREE_CHILDREN (vnode)[0]); + ops->quick_push (SLP_TREE_CHILDREN (node)[1]); + ops->quick_push (SLP_TREE_CHILDREN (node)[0]); + } + else + { + ops->quick_push (SLP_TREE_CHILDREN (vnode)[0]); + ops->quick_push (SLP_TREE_CHILDREN (node)[0]); + ops->quick_push (SLP_TREE_CHILDREN (node)[1]); + } + + return ifn; +} + +/* Attempt to recognize a complex mul pattern. */ + +vect_pattern* +complex_fma_pattern::recognize (slp_tree_to_load_perm_map_t *perm_cache, + slp_tree *node) +{ + auto_vec ops; + complex_operation_t op + = vect_detect_pair_op (*node, true, &ops); + internal_fn ifn + = complex_fma_pattern::matches (op, perm_cache, node, &ops); + if (ifn == IFN_LAST) + return NULL; + + return new complex_fma_pattern (node, &ops, ifn); +} + +/* Perform a replacement of the detected complex mul pattern with the new + instruction sequences. */ + +void +complex_fma_pattern::build (vec_info *vinfo) +{ + SLP_TREE_CHILDREN (*this->m_node).release (); + SLP_TREE_CHILDREN (*this->m_node).create (3); + SLP_TREE_CHILDREN (*this->m_node).safe_splice (this->m_ops); + + complex_pattern::build (vinfo); +} + /******************************************************************************* * Pattern matching definitions ******************************************************************************/ -- 2.30.2