nir: Add a pass for selectively lowering variables to scratch space
[mesa.git] / src / compiler / nir / nir_loop_analyze.c
index 89b8aab9ebf394a0c8d0416440ac60efbc2d4e44..781dac27bb7188d9ad3e53daee077f8acd17aa30 100644 (file)
@@ -480,9 +480,12 @@ find_array_access_via_induction(loop_info_state *state,
          *array_index_out = array_index;
 
       nir_deref_instr *parent = nir_deref_instr_parent(d);
-      assert(glsl_type_is_array_or_matrix(parent->type));
-
-      return glsl_get_length(parent->type);
+      if (glsl_type_is_array_or_matrix(parent->type)) {
+         return glsl_get_length(parent->type);
+      } else {
+         assert(glsl_type_is_vector(parent->type));
+         return glsl_get_vector_elements(parent->type);
+      }
    }
 
    return 0;
@@ -659,7 +662,8 @@ test_iterations(int32_t iter_int, nir_const_value *step,
 static int
 calculate_iterations(nir_const_value *initial, nir_const_value *step,
                      nir_const_value *limit, nir_loop_variable *alu_def,
-                     nir_alu_instr *cond_alu, bool limit_rhs, bool invert_cond)
+                     nir_alu_instr *cond_alu, nir_op alu_op, bool limit_rhs,
+                     bool invert_cond)
 {
    assert(initial != NULL && step != NULL && limit != NULL);
 
@@ -674,10 +678,10 @@ calculate_iterations(nir_const_value *initial, nir_const_value *step,
    nir_alu_type induction_base_type =
       nir_alu_type_get_base_type(nir_op_infos[alu->op].output_type);
    if (induction_base_type == nir_type_int || induction_base_type == nir_type_uint) {
-      assert(nir_alu_type_get_base_type(nir_op_infos[cond_alu->op].input_types[1]) == nir_type_int ||
-             nir_alu_type_get_base_type(nir_op_infos[cond_alu->op].input_types[1]) == nir_type_uint);
+      assert(nir_alu_type_get_base_type(nir_op_infos[alu_op].input_types[1]) == nir_type_int ||
+             nir_alu_type_get_base_type(nir_op_infos[alu_op].input_types[1]) == nir_type_uint);
    } else {
-      assert(nir_alu_type_get_base_type(nir_op_infos[cond_alu->op].input_types[0]) ==
+      assert(nir_alu_type_get_base_type(nir_op_infos[alu_op].input_types[0]) ==
              induction_base_type);
    }
 
@@ -701,7 +705,7 @@ calculate_iterations(nir_const_value *initial, nir_const_value *step,
       trip_offset = 1;
    }
 
-   int iter_int = get_iteration(cond_alu->op, initial, step, limit);
+   int iter_int = get_iteration(alu_op, initial, step, limit);
 
    /* If iter_int is negative the loop is ill-formed or is the conditional is
     * unsigned with a huge iteration count so don't bother going any further.
@@ -724,7 +728,7 @@ calculate_iterations(nir_const_value *initial, nir_const_value *step,
    for (int bias = -1; bias <= 1; bias++) {
       const int iter_bias = iter_int + bias;
 
-      if (test_iterations(iter_bias, step, limit, cond_alu->op, bit_size,
+      if (test_iterations(iter_bias, step, limit, alu_op, bit_size,
                           induction_base_type, initial,
                           limit_rhs, invert_cond)) {
          return iter_bias > 0 ? iter_bias - trip_offset : iter_bias;
@@ -791,6 +795,70 @@ get_induction_and_limit_vars(nir_alu_instr *alu, nir_loop_variable **ind,
    return limit_rhs;
 }
 
+static void
+try_find_trip_count_vars_in_iand(nir_alu_instr **alu,
+                                 nir_loop_variable **ind,
+                                 nir_loop_variable **limit,
+                                 bool *limit_rhs,
+                                 loop_info_state *state)
+{
+   assert((*alu)->op == nir_op_ieq || (*alu)->op == nir_op_inot);
+   /* Look through inot(iand(x, y)) or iand(x, y) == 0 for a loop condition */
+   nir_ssa_def *iand_def = (*alu)->src[0].src.ssa;
+
+   if ((*alu)->op == nir_op_ieq) {
+      nir_ssa_def *zero_def = (*alu)->src[1].src.ssa;
+
+      if (iand_def->parent_instr->type != nir_instr_type_alu ||
+          zero_def->parent_instr->type != nir_instr_type_load_const) {
+
+         /* Operands may be the other way around, so swap src[0] and src[1] */
+         iand_def = (*alu)->src[1].src.ssa;
+         zero_def = (*alu)->src[0].src.ssa;
+
+         /* Still no load_const to compare against, so give up */
+         if (zero_def->parent_instr->type != nir_instr_type_load_const)
+            return;
+      }
+
+      /* If the constant being compared against is not zero then return */
+      nir_const_value zero =
+         nir_instr_as_load_const(zero_def->parent_instr)->value;
+      if (zero.i32[0] != 0)
+         return;
+   }
+
+   if (iand_def->parent_instr->type != nir_instr_type_alu)
+      return;
+
+   nir_alu_instr *iand = nir_instr_as_alu(iand_def->parent_instr);
+   if (iand->op != nir_op_iand)
+      return;
+   /* On every early return above *alu, *ind and *limit are left untouched */
+   /* Try iand src[0]: if it is a supported terminator condition, get the
+    * induction and limit vars. NOTE: *alu is replaced even if unsupported.
+    */
+   nir_ssa_def *src = iand->src[0].src.ssa;
+   if (src->parent_instr->type == nir_instr_type_alu) {
+      *alu = nir_instr_as_alu(src->parent_instr);
+      if (is_supported_terminator_condition(*alu))
+         *limit_rhs = get_induction_and_limit_vars(*alu, ind, limit, state);
+   }
+
+   /* Try src[1] unless *ind is a basic induction var with a constant limit */
+   if (*ind == NULL || (*ind && (*ind)->type != basic_induction) ||
+       !is_var_constant(*limit)) {
+      src = iand->src[1].src.ssa;
+      if (src->parent_instr->type == nir_instr_type_alu) {
+         nir_alu_instr *tmp_alu = nir_instr_as_alu(src->parent_instr);
+         if (is_supported_terminator_condition(tmp_alu)) {
+            *alu = tmp_alu;
+            *limit_rhs = get_induction_and_limit_vars(*alu, ind, limit, state);
+         }
+      }
+   }
+}
+
 /* Run through each of the terminators of the loop and try to infer a possible
  * trip-count. We need to check them all, and set the lowest trip-count as the
  * trip-count of our loop. If one of the terminators has an undecidable
@@ -818,16 +886,37 @@ find_trip_count(loop_info_state *state)
       }
 
       nir_alu_instr *alu = nir_instr_as_alu(terminator->conditional_instr);
-      if (!is_supported_terminator_condition(alu)) {
-         trip_count_known = false;
-         continue;
-      }
+      nir_op alu_op = alu->op;
 
-      nir_loop_variable *basic_ind;
+      bool limit_rhs;
+      nir_loop_variable *basic_ind = NULL;
       nir_loop_variable *limit;
-      bool limit_rhs = get_induction_and_limit_vars(alu, &basic_ind, &limit,
-                                                    state);
-      terminator->induction_rhs = !limit_rhs;
+      if (alu->op == nir_op_inot || alu->op == nir_op_ieq) {
+         nir_alu_instr *new_alu = alu;
+         try_find_trip_count_vars_in_iand(&new_alu, &basic_ind, &limit,
+                                          &limit_rhs, state);
+
+         /* The loop is exiting on (x && y) == 0 so we need to get the
+          * inverse of x or y (i.e. whichever contained the induction var) in
+          * order to compute the trip count.
+          */
+         if (basic_ind && basic_ind->type == basic_induction) {
+            alu = new_alu;
+            alu_op = inverse_comparison(alu);
+            trip_count_known = false;
+            terminator->exact_trip_count_unknown = true;
+         }
+      }
+
+      if (!basic_ind) {
+         if (!is_supported_terminator_condition(alu)) {
+            trip_count_known = false;
+            continue;
+         }
+
+         limit_rhs = get_induction_and_limit_vars(alu, &basic_ind, &limit,
+                                                  state);
+      }
 
       /* The comparison has to have a basic induction variable for us to be
        * able to find trip counts.
@@ -837,6 +926,8 @@ find_trip_count(loop_info_state *state)
          continue;
       }
 
+      terminator->induction_rhs = !limit_rhs;
+
       /* Attempt to find a constant limit for the loop */
       nir_const_value limit_val;
       if (is_var_constant(limit)) {
@@ -874,7 +965,7 @@ find_trip_count(loop_info_state *state)
       int iterations = calculate_iterations(&initial_val, &step_val,
                                             &limit_val,
                                             basic_ind->ind->alu_def, alu,
-                                            limit_rhs,
+                                            alu_op, limit_rhs,
                                             terminator->continue_from_then);
 
       /* Where we not able to calculate the iteration count */