openmp: Compute triangular loop number of iterations at compile time
authorJakub Jelinek <jakub@redhat.com>
Mon, 22 Jun 2020 09:06:08 +0000 (11:06 +0200)
committerJakub Jelinek <jakub@redhat.com>
Mon, 22 Jun 2020 09:06:59 +0000 (11:06 +0200)
2020-06-22  Jakub Jelinek  <jakub@redhat.com>

* omp-general.c (omp_extract_for_data): For triangular loops with
all loop invariant expressions constant where the innermost loop is
executed at least once compute number of iterations at compile time.

gcc/omp-general.c

index fc889002a520c0e575689c74360033ce7721cf1b..10196f671e1697a7470c817a4ee7893edc889fa5 100644 (file)
@@ -313,6 +313,44 @@ omp_extract_for_data (gomp_for *for_stmt, struct omp_for_data *fd,
     }
 
   int cnt = fd->ordered ? fd->ordered : fd->collapse;
+  int single_nonrect = -1;
+  tree single_nonrect_count = NULL_TREE;
+  enum tree_code single_nonrect_cond_code = ERROR_MARK;
+  for (i = 1; i < cnt; i++)
+    {
+      tree n1 = gimple_omp_for_initial (for_stmt, i);
+      tree n2 = gimple_omp_for_final (for_stmt, i);
+      if (TREE_CODE (n1) == TREE_VEC)
+       {
+         if (fd->non_rect)
+           {
+             single_nonrect = -1;
+             break;
+           }
+         for (int j = i - 1; j >= 0; j--)
+           if (TREE_VEC_ELT (n1, 0) == gimple_omp_for_index (for_stmt, j))
+             {
+               single_nonrect = j;
+               break;
+             }
+         fd->non_rect = true;
+       }
+      else if (TREE_CODE (n2) == TREE_VEC)
+       {
+         if (fd->non_rect)
+           {
+             single_nonrect = -1;
+             break;
+           }
+         for (int j = i - 1; j >= 0; j--)
+           if (TREE_VEC_ELT (n2, 0) == gimple_omp_for_index (for_stmt, j))
+             {
+               single_nonrect = j;
+               break;
+             }
+         fd->non_rect = true;
+       }
+    }
   for (i = 0; i < cnt; i++)
     {
       if (i == 0
@@ -444,8 +482,90 @@ omp_extract_for_data (gomp_for *for_stmt, struct omp_for_data *fd,
 
       if (collapse_count && *collapse_count == NULL)
        {
+         if (count && integer_zerop (count))
+           continue;
+         tree n1first = NULL_TREE, n2first = NULL_TREE;
+         tree n1last = NULL_TREE, n2last = NULL_TREE;
+         tree ostep = NULL_TREE;
          if (loop->m1 || loop->m2)
-           t = NULL_TREE;
+           {
+             if (count == NULL_TREE)
+               continue;
+             if (single_nonrect == -1
+                 || (loop->m1 && TREE_CODE (loop->m1) != INTEGER_CST)
+                 || (loop->m2 && TREE_CODE (loop->m2) != INTEGER_CST))
+               {
+                 count = NULL_TREE;
+                 continue;
+               }
+             tree var = gimple_omp_for_initial (for_stmt, single_nonrect);
+             tree itype = TREE_TYPE (var);
+             tree first = gimple_omp_for_initial (for_stmt, single_nonrect);
+             t = gimple_omp_for_incr (for_stmt, single_nonrect);
+             ostep = omp_get_for_step_from_incr (loc, t);
+             t = fold_binary (MINUS_EXPR, long_long_unsigned_type_node,
+                              single_nonrect_count,
+                              build_one_cst (long_long_unsigned_type_node));
+             t = fold_convert (itype, t);
+             first = fold_convert (itype, first);
+             ostep = fold_convert (itype, ostep);
+             tree last = fold_binary (PLUS_EXPR, itype, first,
+                                      fold_binary (MULT_EXPR, itype, t,
+                                                   ostep));
+             if (TREE_CODE (first) != INTEGER_CST
+                 || TREE_CODE (last) != INTEGER_CST)
+               {
+                 count = NULL_TREE;
+                 continue;
+               }
+             if (loop->m1)
+               {
+                 tree m1 = fold_convert (itype, loop->m1);
+                 tree n1 = fold_convert (itype, loop->n1);
+                 n1first = fold_binary (PLUS_EXPR, itype,
+                                        fold_binary (MULT_EXPR, itype,
+                                                     first, m1), n1);
+                 n1last = fold_binary (PLUS_EXPR, itype,
+                                       fold_binary (MULT_EXPR, itype,
+                                                    last, m1), n1);
+               }
+             else
+               n1first = n1last = loop->n1;
+             if (loop->m2)
+               {
+                 tree n2 = fold_convert (itype, loop->n2);
+                 tree m2 = fold_convert (itype, loop->m2);
+                 n2first = fold_binary (PLUS_EXPR, itype,
+                                        fold_binary (MULT_EXPR, itype,
+                                                     first, m2), n2);
+                 n2last = fold_binary (PLUS_EXPR, itype,
+                                       fold_binary (MULT_EXPR, itype,
+                                                    last, m2), n2);
+               }
+             else
+               n2first = n2last = loop->n2;
+             n1first = fold_convert (TREE_TYPE (loop->v), n1first);
+             n2first = fold_convert (TREE_TYPE (loop->v), n2first);
+             n1last = fold_convert (TREE_TYPE (loop->v), n1last);
+             n2last = fold_convert (TREE_TYPE (loop->v), n2last);
+             t = fold_binary (loop->cond_code, boolean_type_node,
+                              n1first, n2first);
+             tree t2 = fold_binary (loop->cond_code, boolean_type_node,
+                                    n1last, n2last);
+             if (t && t2 && integer_nonzerop (t) && integer_nonzerop (t2))
+               /* All outer loop iterators have at least one inner loop
+                  iteration.  Try to compute the count at compile time.  */
+               t = NULL_TREE;
+             else if (t && t2 && integer_zerop (t) && integer_zerop (t2))
+               /* No iterations of the inner loop.  count will be set to
+                  zero cst below.  */;
+             else
+               {
+                 /* Punt (for now).  */
+                 count = NULL_TREE;
+                 continue;
+               }
+           }
          else
            t = fold_binary (loop->cond_code, boolean_type_node,
                             fold_convert (TREE_TYPE (loop->v), loop->n1),
@@ -454,8 +574,6 @@ omp_extract_for_data (gomp_for *for_stmt, struct omp_for_data *fd,
            count = build_zero_cst (long_long_unsigned_type_node);
          else if ((i == 0 || count != NULL_TREE)
                   && TREE_CODE (TREE_TYPE (loop->v)) == INTEGER_TYPE
-                  && loop->m1 == NULL_TREE
-                  && loop->m2 == NULL_TREE
                   && TREE_CONSTANT (loop->n1)
                   && TREE_CONSTANT (loop->n2)
                   && TREE_CODE (loop->step) == INTEGER_CST)
@@ -465,31 +583,89 @@ omp_extract_for_data (gomp_for *for_stmt, struct omp_for_data *fd,
              if (POINTER_TYPE_P (itype))
                itype = signed_type_for (itype);
              t = build_int_cst (itype, (loop->cond_code == LT_EXPR ? -1 : 1));
-             t = fold_build2_loc (loc, PLUS_EXPR, itype,
-                                  fold_convert_loc (loc, itype, loop->step),
-                                  t);
-             t = fold_build2_loc (loc, PLUS_EXPR, itype, t,
-                                  fold_convert_loc (loc, itype, loop->n2));
-             t = fold_build2_loc (loc, MINUS_EXPR, itype, t,
-                                  fold_convert_loc (loc, itype, loop->n1));
-             if (TYPE_UNSIGNED (itype) && loop->cond_code == GT_EXPR)
+             t = fold_build2 (PLUS_EXPR, itype,
+                              fold_convert (itype, loop->step), t);
+             tree n1 = loop->n1;
+             tree n2 = loop->n2;
+             if (loop->m1 || loop->m2)
                {
-                 tree step = fold_convert_loc (loc, itype, loop->step);
-                 t = fold_build2_loc (loc, TRUNC_DIV_EXPR, itype,
-                                      fold_build1_loc (loc, NEGATE_EXPR,
-                                                       itype, t),
-                                      fold_build1_loc (loc, NEGATE_EXPR,
-                                                       itype, step));
+                 gcc_assert (single_nonrect != -1);
+                 if (single_nonrect_cond_code == LT_EXPR)
+                   {
+                     n1 = n1first;
+                     n2 = n2first;
+                   }
+                 else
+                   {
+                     n1 = n1last;
+                     n2 = n2last;
+                   }
                }
+             t = fold_build2 (PLUS_EXPR, itype, t, fold_convert (itype, n2));
+             t = fold_build2 (MINUS_EXPR, itype, t, fold_convert (itype, n1));
+             tree step = fold_convert_loc (loc, itype, loop->step);
+             if (TYPE_UNSIGNED (itype) && loop->cond_code == GT_EXPR)
+               t = fold_build2 (TRUNC_DIV_EXPR, itype,
+                                fold_build1 (NEGATE_EXPR, itype, t),
+                                fold_build1 (NEGATE_EXPR, itype, step));
              else
-               t = fold_build2_loc (loc, TRUNC_DIV_EXPR, itype, t,
-                                    fold_convert_loc (loc, itype,
-                                                      loop->step));
-             t = fold_convert_loc (loc, long_long_unsigned_type_node, t);
-             if (count != NULL_TREE)
-               count = fold_build2_loc (loc, MULT_EXPR,
-                                        long_long_unsigned_type_node,
-                                        count, t);
+               t = fold_build2 (TRUNC_DIV_EXPR, itype, t, step);
+             tree llutype = long_long_unsigned_type_node;
+             t = fold_convert (llutype, t);
+             if (loop->m1 || loop->m2)
+               {
+                 /* t is number of iterations of inner loop at either first
+                    or last value of the outer iterator (the one with fewer
+                    iterations).
+                    Compute t2 = ((m2 - m1) * ostep) / step
+                    (for single_nonrect_cond_code GT_EXPR
+                     t2 = ((m1 - m2) * ostep) / step instead)
+                    and niters = outer_count * t
+                                 + t2 * ((outer_count - 1) * outer_count / 2)
+                  */
+                 tree m1 = loop->m1 ? loop->m1 : integer_zero_node;
+                 tree m2 = loop->m2 ? loop->m2 : integer_zero_node;
+                 m1 = fold_convert (itype, m1);
+                 m2 = fold_convert (itype, m2);
+                 tree t2;
+                 if (single_nonrect_cond_code == LT_EXPR)
+                   t2 = fold_build2 (MINUS_EXPR, itype, m2, m1);
+                 else
+                   t2 = fold_build2 (MINUS_EXPR, itype, m1, m2);
+                 t2 = fold_build2 (MULT_EXPR, itype, t2, ostep);
+                 if (TYPE_UNSIGNED (itype) && loop->cond_code == GT_EXPR)
+                   t2 = fold_build2 (TRUNC_DIV_EXPR, itype,
+                                     fold_build1 (NEGATE_EXPR, itype, t2),
+                                     fold_build1 (NEGATE_EXPR, itype, step));
+                 else
+                   t2 = fold_build2 (TRUNC_DIV_EXPR, itype, t2, step);
+                 t2 = fold_convert (llutype, t2);
+                 t = fold_build2 (MULT_EXPR, llutype, t,
+                                  single_nonrect_count);
+                 tree t3 = fold_build2 (MINUS_EXPR, llutype,
+                                        single_nonrect_count,
+                                        build_one_cst (llutype));
+                 t3 = fold_build2 (MULT_EXPR, llutype, t3,
+                                   single_nonrect_count);
+                 t3 = fold_build2 (TRUNC_DIV_EXPR, llutype, t3,
+                                   build_int_cst (llutype, 2));
+                 t2 = fold_build2 (MULT_EXPR, llutype, t2, t3);
+                 t = fold_build2 (PLUS_EXPR, llutype, t, t2);
+               }
+             if (i == single_nonrect)
+               {
+                 if (integer_zerop (t) || TREE_CODE (t) != INTEGER_CST)
+                   count = t;
+                 else
+                   {
+                     single_nonrect_count = t;
+                     single_nonrect_cond_code = loop->cond_code;
+                     if (count == NULL_TREE)
+                       count = build_one_cst (llutype);
+                   }
+               }
+             else if (count != NULL_TREE)
+               count = fold_build2 (MULT_EXPR, llutype, count, t);
              else
                count = t;
              if (TREE_CODE (count) != INTEGER_CST)