nir/loop_analyze: Refactor detection of limit vars
authorJason Ekstrand <jason@jlekstrand.net>
Mon, 24 Jun 2019 22:33:02 +0000 (17:33 -0500)
committerJason Ekstrand <jason@jlekstrand.net>
Wed, 10 Jul 2019 00:20:59 +0000 (00:20 +0000)
This commit reworks both get_induction_and_limit_vars() and
try_find_trip_count_vars_in_iand to return true on success and not
modify their output parameters on failure.  This makes their callers
significantly simpler.

Reviewed-by: Timothy Arceri <tarceri@itsqueeze.com>
src/compiler/nir/nir_loop_analyze.c

index 979413c2cb026249f2c8bb81c67898d80d1988bc..c64314aa37818305a130ff581d7e2f3730923fe6 100644 (file)
@@ -810,27 +810,32 @@ is_supported_terminator_condition(nir_alu_instr *alu)
 }
 
 static bool
-get_induction_and_limit_vars(nir_alu_instr *alu, nir_loop_variable **ind,
+get_induction_and_limit_vars(nir_alu_instr *alu,
+                             nir_loop_variable **ind,
                              nir_loop_variable **limit,
+                             bool *limit_rhs,
                              loop_info_state *state)
 {
-   bool limit_rhs = true;
-
-   /* We assume that the limit is the "right" operand */
-   *ind = get_loop_var(alu->src[0].src.ssa, state);
-   *limit = get_loop_var(alu->src[1].src.ssa, state);
-
-   if ((*ind)->type != basic_induction) {
-      /* We had it the wrong way, flip things around */
-      *ind = get_loop_var(alu->src[1].src.ssa, state);
-      *limit = get_loop_var(alu->src[0].src.ssa, state);
-      limit_rhs = false;
+   nir_loop_variable *rhs, *lhs;
+   lhs = get_loop_var(alu->src[0].src.ssa, state);
+   rhs = get_loop_var(alu->src[1].src.ssa, state);
+
+   if (lhs->type == basic_induction) {
+      *ind = lhs;
+      *limit = rhs;
+      *limit_rhs = true;
+      return true;
+   } else if (rhs->type == basic_induction) {
+      *ind = rhs;
+      *limit = lhs;
+      *limit_rhs = false;
+      return true;
+   } else {
+      return false;
    }
-
-   return limit_rhs;
 }
 
-static void
+static bool
 try_find_trip_count_vars_in_iand(nir_alu_instr **alu,
                                  nir_loop_variable **ind,
                                  nir_loop_variable **limit,
@@ -848,7 +853,7 @@ try_find_trip_count_vars_in_iand(nir_alu_instr **alu,
 
       /* We don't handle swizzles here */
       if ((*alu)->src[0].swizzle[0] > 0 || (*alu)->src[1].swizzle[0] > 0)
-         return;
+         return false;
 
       if (iand_def->parent_instr->type != nir_instr_type_alu ||
           zero_def->parent_instr->type != nir_instr_type_load_const) {
@@ -859,49 +864,49 @@ try_find_trip_count_vars_in_iand(nir_alu_instr **alu,
 
          /* If we still didn't find what we need then return */
          if (zero_def->parent_instr->type != nir_instr_type_load_const)
-            return;
+            return false;
       }
 
       /* If the loop is not breaking on (x && y) == 0 then return */
       nir_const_value *zero =
          nir_instr_as_load_const(zero_def->parent_instr)->value;
       if (zero[0].i32 != 0)
-         return;
+         return false;
    }
 
    if (iand_def->parent_instr->type != nir_instr_type_alu)
-      return;
+      return false;
 
    nir_alu_instr *iand = nir_instr_as_alu(iand_def->parent_instr);
    if (iand->op != nir_op_iand)
-      return;
+      return false;
 
    /* We don't handle swizzles here */
    if ((*alu)->src[0].swizzle[0] > 0 || (*alu)->src[1].swizzle[0] > 0)
-      return;
+      return false;
 
    /* Check if iand src is a terminator condition and try get induction var
     * and trip limit var.
     */
-   nir_ssa_def *src = iand->src[0].src.ssa;
-   if (src->parent_instr->type == nir_instr_type_alu) {
-      *alu = nir_instr_as_alu(src->parent_instr);
-      if (is_supported_terminator_condition(*alu))
-         *limit_rhs = get_induction_and_limit_vars(*alu, ind, limit, state);
-   }
-
-   /* Try the other iand src if needed */
-   if (*ind == NULL || (*ind && (*ind)->type != basic_induction) ||
-       !is_var_constant(*limit)) {
-      src = iand->src[1].src.ssa;
+   bool found_induction_var = false;
+   for (unsigned i = 0; i < 2; i++) {
+      nir_ssa_def *src = iand->src[i].src.ssa;
       if (src->parent_instr->type == nir_instr_type_alu) {
-         nir_alu_instr *tmp_alu = nir_instr_as_alu(src->parent_instr);
-         if (is_supported_terminator_condition(tmp_alu)) {
-            *alu = tmp_alu;
-            *limit_rhs = get_induction_and_limit_vars(*alu, ind, limit, state);
+         nir_alu_instr *src_alu = nir_instr_as_alu(src->parent_instr);
+         if (is_supported_terminator_condition(src_alu) &&
+             get_induction_and_limit_vars(src_alu, ind, limit,
+                                          limit_rhs, state)) {
+            *alu = src_alu;
+            found_induction_var = true;
+
+            /* If we've found one with a constant limit, stop. */
+            if (is_var_constant(*limit))
+               return true;
          }
       }
    }
+
+   return found_induction_var;
 }
 
 /* Run through each of the terminators of the loop and try to infer a possible
@@ -936,37 +941,29 @@ find_trip_count(loop_info_state *state)
       bool limit_rhs;
       nir_loop_variable *basic_ind = NULL;
       nir_loop_variable *limit;
-      if (alu->op == nir_op_inot || alu->op == nir_op_ieq) {
-         nir_alu_instr *new_alu = alu;
-         try_find_trip_count_vars_in_iand(&new_alu, &basic_ind, &limit,
-                                          &limit_rhs, state);
-
+      if ((alu->op == nir_op_inot || alu->op == nir_op_ieq) &&
+          try_find_trip_count_vars_in_iand(&alu, &basic_ind, &limit,
+                                           &limit_rhs, state)) {
          /* The loop is exiting on (x && y) == 0 so we need to get the
           * inverse of x or y (i.e. which ever contained the induction var) in
           * order to compute the trip count.
           */
-         if (basic_ind && basic_ind->type == basic_induction) {
-            alu = new_alu;
-            alu_op = inverse_comparison(alu);
-            trip_count_known = false;
-            terminator->exact_trip_count_unknown = true;
-         }
+         alu_op = inverse_comparison(alu);
+         trip_count_known = false;
+         terminator->exact_trip_count_unknown = true;
       }
 
       if (!basic_ind) {
-         if (!is_supported_terminator_condition(alu)) {
-            trip_count_known = false;
-            continue;
+         if (is_supported_terminator_condition(alu)) {
+            get_induction_and_limit_vars(alu, &basic_ind,
+                                         &limit, &limit_rhs, state);
          }
-
-         limit_rhs = get_induction_and_limit_vars(alu, &basic_ind, &limit,
-                                                  state);
       }
 
       /* The comparison has to have a basic induction variable for us to be
        * able to find trip counts.
        */
-      if (basic_ind->type != basic_induction) {
+      if (!basic_ind) {
          trip_count_known = false;
          continue;
       }